Unverified commit 31f81685, authored by J joanna.wozna.intel, committed by GitHub

Correct lstm qat test (#51499)

Parent 6ba0507d
@@ -37,7 +37,7 @@ void ComputePropagateScalesMkldnnPass::GetTensorFromVector(
 void ComputePropagateScalesMkldnnPass::GetQuantInfo(
     ir::Graph* graph, StringPairMap* var_quant_scales) const {
   std::unordered_map<std::string, std::vector<float>> info_map{};
-  GetInfoFromTheFirstOp(graph, "has_quant_info", "var_quant_scales", &info_map);
+  GetInfoFromTheTmpOp(graph, "has_quant_info", "var_quant_scales", &info_map);
 
   for (auto iter = info_map.begin(); iter != info_map.end(); iter++) {
     phi::DenseTensor tensor;
@@ -510,9 +510,9 @@ void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const {
   UpdateReluOutputScales(graph, &var_quant_scales);
   PropagateScales(graph, &var_quant_scales, scale_immutable_ops);
 
-  // save var_quant_scales in the first op's attr
+  // save var_quant_scales in the temporary save op's attr
   // for cpu_quantize_pass
-  SaveInfoInTheFirstOp(
+  SaveInfoInTheTmpOp(
       graph, "has_quant_info", "var_quant_scales", var_quant_scales);
 }
......
@@ -435,7 +435,7 @@ bool CPUQuantizePass::IsOpQuantized(const Node* node) const {
 }
 
 void CPUQuantizePass::GetQuantInfo(Graph* graph) const {
-  GetInfoFromTheFirstOp(
+  GetInfoFromTheTmpOp(
       graph, "has_quant_info", "var_quant_scales", var_quant_scales_);
 }
@@ -1250,6 +1250,10 @@ void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const {
   bool is_x_unsigned{false};
   auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);
 
+  // In the QAT process, scales are prepared only for the int8 data type,
+  // so the lstm scales should behave as if the input were int8 to get correct accuracy.
+  is_x_unsigned = false;
+
   double input_x_shift{128.};
   if (is_x_unsigned) input_x_shift = 0.;
......
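As background for the `is_x_unsigned = false` override above, here is a minimal sketch (not Paddle code; the helper name is hypothetical) of what the `input_x_shift` of 128 corresponds to: scales prepared for signed int8 data are reused by shifting the signed range into the unsigned range that the fused oneDNN LSTM consumes, whereas a genuinely unsigned input would use a shift of 0.

    // Illustrative sketch only, not Paddle code: how input_x_shift maps a
    // float activation into the u8 range expected by the fused LSTM kernel.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Hypothetical helper. shift == 128 treats the data as signed int8 and
    // moves [-128, 127] into [0, 255]; shift == 0 assumes the input is
    // already unsigned.
    inline uint8_t QuantizeLstmInput(float x, float scale, float shift) {
      const float q = std::round(x * scale) + shift;
      return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, q)));
    }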
@@ -25,7 +25,7 @@ namespace ir {
 using StringPairMap =
     std::unordered_map<std::string, std::pair<bool, phi::DenseTensor>>;
 
-static void SaveInfoInTheFirstOp(
+static void SaveInfoInTheTmpOp(
     ir::Graph* graph,
     const std::string& flag,
     const std::string& key_suffix,
@@ -33,48 +33,39 @@ static void SaveInfoInTheFirstOp(
   VLOG(3) << "save variables in the first op's attr";
 
   const std::string suffix = "_" + key_suffix + "_" + flag;
-  for (auto* op_node :
-       ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
-    if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
-        op_node->Op()->Type() == "fetch" ||
-        op_node->Op()->Type() == "fill_constant")
-      continue;
-
-    op_node->Op()->SetAttr(flag, true);
-    for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) {
-      op_node->Op()->SetAttr(iter->first + suffix, iter->second);
-    }
-    break;
-  }
+
+  OpDesc op_desc;
+  op_desc.SetType("save");
+  auto* op_node = graph->CreateOpNode(&op_desc);
+
+  op_node->Op()->SetAttr(flag, true);
+  for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) {
+    op_node->Op()->SetAttr(iter->first + suffix, iter->second);
+  }
 }
 
-static void SaveInfoInTheFirstOp(ir::Graph* graph,
-                                 const std::string& flag,
-                                 const std::string& key_suffix,
-                                 const StringPairMap& info_map) {
+static void SaveInfoInTheTmpOp(ir::Graph* graph,
+                               const std::string& flag,
+                               const std::string& key_suffix,
+                               const StringPairMap& info_map) {
   VLOG(3) << "save variables in the first op's attr";
 
   const std::string suffix = "_" + key_suffix + "_" + flag;
-  for (auto* op_node :
-       ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
-    if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
-        op_node->Op()->Type() == "fetch" ||
-        op_node->Op()->Type() == "fill_constant")
-      continue;
-
-    op_node->Op()->SetAttr(flag, true);
-    for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) {
-      auto* data = iter->second.second.data<float>();
-      std::vector<float> data_v(data, data + iter->second.second.numel());
-      op_node->Op()->SetAttr(iter->first + suffix + "_unsigned",
-                             iter->second.first);
-      op_node->Op()->SetAttr(iter->first + suffix, data_v);
-    }
-    break;
-  }
+
+  OpDesc op_desc;
+  op_desc.SetType("save");
+  auto* op_node = graph->CreateOpNode(&op_desc);
+
+  op_node->Op()->SetAttr(flag, true);
+  for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) {
+    auto* data = iter->second.second.data<float>();
+    std::vector<float> data_v(data, data + iter->second.second.numel());
+    op_node->Op()->SetAttr(iter->first + suffix + "_unsigned",
+                           iter->second.first);
+    op_node->Op()->SetAttr(iter->first + suffix, data_v);
+  }
 }
 
-static void GetInfoFromTheFirstOp(
+static void GetInfoFromTheTmpOp(
     ir::Graph* graph,
     const std::string& flag,
     const std::string& key_suffix,
@@ -84,9 +75,7 @@ static void GetInfoFromTheFirstOp(
   const std::string suffix = "_" + key_suffix + "_" + flag;
   for (auto* op_node :
        ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
-    if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
-        op_node->Op()->Type() == "fetch")
-      continue;
+    if (!op_node->IsOp() || op_node->Op()->Type() != "save") continue;
 
     auto* op_desc = op_node->Op();
     if (op_desc->GetAttrIfExists<bool>(flag)) {
@@ -102,24 +91,23 @@ static void GetInfoFromTheFirstOp(
           op_desc->RemoveAttr(fake_name);
         }
       }
+      graph->RemoveNode(op_node);
       break;
     }
   }
 }
 
-static void GetInfoFromTheFirstOp(ir::Graph* graph,
-                                  const std::string& flag,
-                                  const std::string& key_suffix,
-                                  StringPairMap* info_map) {
+static void GetInfoFromTheTmpOp(ir::Graph* graph,
+                                const std::string& flag,
+                                const std::string& key_suffix,
+                                StringPairMap* info_map) {
   VLOG(3) << "get variables from the first op's attr";
 
   const std::string unsigned_flag = "_unsigned";
   const std::string suffix = "_" + key_suffix + "_" + flag;
   const std::string suffix_is_unsigned = suffix + unsigned_flag;
   for (auto* op_node :
        ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
-    if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
-        op_node->Op()->Type() == "fetch")
-      continue;
+    if (!op_node->IsOp() || op_node->Op()->Type() != "save") continue;
 
     auto* op_desc = op_node->Op();
     if (op_desc->GetAttrIfExists<bool>(flag)) {
@@ -150,6 +138,7 @@ static void GetInfoFromTheFirstOp(ir::Graph* graph,
         op_desc->RemoveAttr(vector_name);
       }
     }
+      graph->RemoveNode(op_node);
       break;
     }
   }
......
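For orientation, a minimal usage sketch of the renamed helpers above (the wrapper function, the variable name "conv2d_out", and the scale value are made up for illustration): one pass stores the collected scales on a dedicated temporary "save" op, and a later pass finds that op by its type, restores the map, and removes the node from the graph.

    // Sketch only; assumes it is compiled inside namespace paddle::framework::ir
    // (where mkldnn_pass_util.h lives) and that `graph` is a valid ir::Graph*.
    static void StashAndRestoreScalesExample(ir::Graph* graph) {
      // Placeholder tensor name and scale value.
      std::unordered_map<std::string, std::vector<float>> scales{
          {"conv2d_out", {0.017f}}};
      SaveInfoInTheTmpOp(graph, "has_quant_info", "var_quant_scales", scales);

      // Later, e.g. in compute_propagate_scales_mkldnn_pass or cpu_quantize_pass:
      std::unordered_map<std::string, std::vector<float>> restored;
      GetInfoFromTheTmpOp(graph, "has_quant_info", "var_quant_scales", &restored);
      // GetInfoFromTheTmpOp also removes the temporary "save" node from the
      // graph, so it never reaches the final optimized program.
    }

The temporary op is located by its "save" type rather than by its position in topological order, which is what the loosened loop conditions in GetInfoFromTheTmpOp above implement.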
@@ -754,9 +754,9 @@ void QuantDequantMkldnnPass::ApplyImpl(ir::Graph* graph) const {
   UpdateActivations(graph);
   RemoveCtrlVars(graph);
 
-  // save var_quant_scales in the first op's attr
+  // save var_quant_scales in the temporary save op's attr
   // for compute_propagate_scales_mkldnn_pass
-  SaveInfoInTheFirstOp(
+  SaveInfoInTheTmpOp(
       graph, "has_quant_info", "var_quant_scales", var_quant_scales);
 }
......
@@ -430,7 +430,6 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("simplify_with_basic_ops_pass");
     passes_.push_back("quant_dequant_mkldnn_pass");
     passes_.push_back("mkldnn_placement_pass");
-    passes_.push_back("constant_folding_pass");
     passes_.push_back("squeeze2_transpose2_onednn_fuse_pass");
     passes_.push_back("layer_norm_fuse_pass");
     passes_.push_back("attention_lstm_fuse_pass");
@@ -485,6 +484,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("quant_transpose2_dequant_onednn_fuse_pass");
     passes_.push_back("int8_scale_calculation_mkldnn_pass");
     passes_.push_back("params_quantization_mkldnn_pass");
+    passes_.push_back("constant_folding_pass");
   }
   use_mkldnn_int8_ = true;
 #else
......
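A hedged usage sketch of this pass-order change from the C++ inference API (the model path is hypothetical, and the header name and AllPasses() accessor may differ between Paddle versions): enabling the oneDNN INT8 pipeline and printing the pass list should now show constant_folding_pass after the quantization passes rather than before them.

    #include <iostream>
    #include "paddle_inference_api.h"  // header name/path may differ per install

    int main() {
      paddle::AnalysisConfig config("./lstm_int8_model");  // hypothetical model dir
      config.EnableMKLDNN();
      config.EnableMkldnnInt8();
      // With this change constant_folding_pass runs after the int8
      // scale-calculation and quantization passes instead of before them.
      for (const auto& pass : config.pass_builder()->AllPasses()) {
        std::cout << pass << "\n";
      }
      return 0;
    }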
@@ -116,10 +116,9 @@ class TestLstmModelPTQ(unittest.TestCase):
         config.switch_ir_optim(True)
         config.enable_mkldnn()
         config.disable_mkldnn_fc_passes()  # fc passes caused dnnl error
-        config.pass_builder().insert_pass(5, "fc_lstm_fuse_pass")
         config.set_mkldnn_cache_capacity(mkldnn_cache_capacity)
         if mode == "ptq":
+            # For this pass to work properly, it must be added before fc_fuse_pass
+            config.pass_builder().insert_pass(5, "fc_lstm_fuse_pass")
             config.enable_quantizer()
             config.quantizer_config().set_quant_data(warmup_data)
             config.quantizer_config().set_quant_batch_size(1)
@@ -244,7 +243,7 @@ class TestLstmModelPTQ(unittest.TestCase):
         )
         (quant_hx_acc, quant_ctc_acc, quant_fps) = self.run_program(
-            quant_model + "_int8",
+            quant_model,
             infer_data,
             num_threads,
             mkldnn_cache_capacity,
......
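The same ordering constraint, sketched through the C++ AnalysisConfig API for comparison (the model path is hypothetical and the index 5 simply mirrors the Python test above; the actual position of fc_fuse_pass can vary between versions):

    #include "paddle_inference_api.h"  // header name/path may differ per install

    int main() {
      paddle::AnalysisConfig config("./lstm_model");  // hypothetical model dir
      config.SwitchIrOptim(true);
      config.EnableMKLDNN();
      // fc_lstm_fuse_pass only works if it is registered before fc_fuse_pass,
      // hence the explicit insertion position; in the test this is done only
      // for the PTQ run.
      config.pass_builder()->InsertPass(5, "fc_lstm_fuse_pass");
      return 0;
    }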