未验证 提交 5fc92943 编写于 作者: Y yeliang2258 提交者: GitHub

Fix ComputePropagateScalesMkldnnPass of MKLDNN (#47574)

* add constant_folding_pass pass for mkldnn int8

* update UpdateScaleOpInOutScales
上级 3de3e45e
......@@ -336,27 +336,46 @@ void ComputePropagateScalesMkldnnPass::ComputeWeightScales(
ComputeLstmWeightScales(graph, scope, "WeightX", "WeightH", var_quant_scales);
}
void ComputePropagateScalesMkldnnPass::UpdateScaleOpInScale(
void ComputePropagateScalesMkldnnPass::UpdateScaleOpInOutScales(
Node* op_node,
const std::string& input_name,
const std::string& output_name,
StringPairMap* var_quant_scales) const {
auto iter = var_quant_scales->find(output_name);
if (iter != var_quant_scales->end()) {
auto pair = iter->second;
const auto tensor = pair.second;
auto out_iter = var_quant_scales->find(output_name);
auto input_iter = var_quant_scales->find(input_name);
// All the input and output have scales
if (out_iter != var_quant_scales->end() &&
input_iter != var_quant_scales->end()) {
return;
}
const auto scale = PADDLE_GET_CONST(float, op_node->Op()->GetAttr("scale"));
if (std::abs(scale) < 1e-6 && out_iter != var_quant_scales->end()) {
return;
}
std::string name = input_name;
auto iter = out_iter;
if (input_iter != var_quant_scales->end()) {
iter = input_iter;
name = output_name;
}
phi::DenseTensor tmp_tensor;
auto pair = iter->second;
const auto tensor = pair.second;
tmp_tensor.Resize(tensor.dims());
auto* data = tmp_tensor.mutable_data<float>(platform::CPUPlace());
auto* src_data = tensor.data<float>();
for (int i = 0; i < tensor.numel(); i++) {
data[i] = data[i] * scale;
if (out_iter != var_quant_scales->end()) {
data[i] = src_data[i] / scale;
} else {
data[i] = src_data[i] * scale;
}
auto new_pair = std::make_pair(pair.first, tmp_tensor);
var_quant_scales->insert(std::make_pair(input_name, new_pair));
}
auto new_pair = std::make_pair(pair.first, tmp_tensor);
var_quant_scales->insert(std::make_pair(name, new_pair));
}
std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
......@@ -403,10 +422,12 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
}
} else if (op_name == "scale") {
const std::string output_name = op_node->Op()->Output("Out")[0];
auto out_iter = var_quant_scales->find(output_name);
if (out_iter != var_quant_scales->end()) {
const std::string input_name = op_node->Op()->Input("X")[0];
UpdateScaleOpInScale(
auto out_iter = var_quant_scales->find(output_name);
auto input_iter = var_quant_scales->find(input_name);
if (out_iter != var_quant_scales->end() ||
input_iter != var_quant_scales->end()) {
UpdateScaleOpInOutScales(
op_node, input_name, output_name, var_quant_scales);
}
}
......
......@@ -79,7 +79,7 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase {
void UpdateReluOutputScales(ir::Graph* graph,
StringPairMap* var_quant_scales) const;
void UpdateScaleOpInScale(Node* op_node,
void UpdateScaleOpInOutScales(Node* op_node,
const std::string& input_name,
const std::string& output_name,
StringPairMap* var_quant_scales) const;
......
......@@ -384,6 +384,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("quant_dequant_mkldnn_pass");
passes_.push_back("mkldnn_placement_pass");
passes_.push_back("simplify_with_basic_ops_pass");
passes_.push_back("constant_folding_pass");
passes_.push_back("layer_norm_fuse_pass");
passes_.push_back("attention_lstm_fuse_pass");
passes_.push_back("seqconv_eltadd_relu_fuse_pass");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册