未验证 提交 5fc92943 编写于 作者: Y yeliang2258 提交者: GitHub

Fix ComputePropagateScalesMkldnnPass of MKLDNN (#47574)

* add constant_folding_pass pass for mkldnn int8

* update UpdateScaleOpInOutScales
上级 3de3e45e
...@@ -336,27 +336,46 @@ void ComputePropagateScalesMkldnnPass::ComputeWeightScales( ...@@ -336,27 +336,46 @@ void ComputePropagateScalesMkldnnPass::ComputeWeightScales(
ComputeLstmWeightScales(graph, scope, "WeightX", "WeightH", var_quant_scales); ComputeLstmWeightScales(graph, scope, "WeightX", "WeightH", var_quant_scales);
} }
void ComputePropagateScalesMkldnnPass::UpdateScaleOpInScale( void ComputePropagateScalesMkldnnPass::UpdateScaleOpInOutScales(
Node* op_node, Node* op_node,
const std::string& input_name, const std::string& input_name,
const std::string& output_name, const std::string& output_name,
StringPairMap* var_quant_scales) const { StringPairMap* var_quant_scales) const {
auto iter = var_quant_scales->find(output_name); auto out_iter = var_quant_scales->find(output_name);
if (iter != var_quant_scales->end()) { auto input_iter = var_quant_scales->find(input_name);
auto pair = iter->second; // All the input and output have scales
const auto tensor = pair.second; if (out_iter != var_quant_scales->end() &&
input_iter != var_quant_scales->end()) {
const auto scale = PADDLE_GET_CONST(float, op_node->Op()->GetAttr("scale")); return;
phi::DenseTensor tmp_tensor; }
tmp_tensor.Resize(tensor.dims());
auto* data = tmp_tensor.mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < tensor.numel(); i++) {
data[i] = data[i] * scale;
}
auto new_pair = std::make_pair(pair.first, tmp_tensor); const auto scale = PADDLE_GET_CONST(float, op_node->Op()->GetAttr("scale"));
var_quant_scales->insert(std::make_pair(input_name, new_pair)); if (std::abs(scale) < 1e-6 && out_iter != var_quant_scales->end()) {
return;
} }
std::string name = input_name;
auto iter = out_iter;
if (input_iter != var_quant_scales->end()) {
iter = input_iter;
name = output_name;
}
phi::DenseTensor tmp_tensor;
auto pair = iter->second;
const auto tensor = pair.second;
tmp_tensor.Resize(tensor.dims());
auto* data = tmp_tensor.mutable_data<float>(platform::CPUPlace());
auto* src_data = tensor.data<float>();
for (int i = 0; i < tensor.numel(); i++) {
if (out_iter != var_quant_scales->end()) {
data[i] = src_data[i] / scale;
} else {
data[i] = src_data[i] * scale;
}
}
auto new_pair = std::make_pair(pair.first, tmp_tensor);
var_quant_scales->insert(std::make_pair(name, new_pair));
} }
std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales( std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
...@@ -403,10 +422,12 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales( ...@@ -403,10 +422,12 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
} }
} else if (op_name == "scale") { } else if (op_name == "scale") {
const std::string output_name = op_node->Op()->Output("Out")[0]; const std::string output_name = op_node->Op()->Output("Out")[0];
const std::string input_name = op_node->Op()->Input("X")[0];
auto out_iter = var_quant_scales->find(output_name); auto out_iter = var_quant_scales->find(output_name);
if (out_iter != var_quant_scales->end()) { auto input_iter = var_quant_scales->find(input_name);
const std::string input_name = op_node->Op()->Input("X")[0]; if (out_iter != var_quant_scales->end() ||
UpdateScaleOpInScale( input_iter != var_quant_scales->end()) {
UpdateScaleOpInOutScales(
op_node, input_name, output_name, var_quant_scales); op_node, input_name, output_name, var_quant_scales);
} }
} }
......
...@@ -79,10 +79,10 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase { ...@@ -79,10 +79,10 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase {
void UpdateReluOutputScales(ir::Graph* graph, void UpdateReluOutputScales(ir::Graph* graph,
StringPairMap* var_quant_scales) const; StringPairMap* var_quant_scales) const;
void UpdateScaleOpInScale(Node* op_node, void UpdateScaleOpInOutScales(Node* op_node,
const std::string& input_name, const std::string& input_name,
const std::string& output_name, const std::string& output_name,
StringPairMap* var_quant_scales) const; StringPairMap* var_quant_scales) const;
std::unordered_set<std::string> UpdateScales( std::unordered_set<std::string> UpdateScales(
ir::Graph* graph, ir::Graph* graph,
......
...@@ -384,6 +384,7 @@ void CpuPassStrategy::EnableMkldnnInt8() { ...@@ -384,6 +384,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("quant_dequant_mkldnn_pass"); passes_.push_back("quant_dequant_mkldnn_pass");
passes_.push_back("mkldnn_placement_pass"); passes_.push_back("mkldnn_placement_pass");
passes_.push_back("simplify_with_basic_ops_pass"); passes_.push_back("simplify_with_basic_ops_pass");
passes_.push_back("constant_folding_pass");
passes_.push_back("layer_norm_fuse_pass"); passes_.push_back("layer_norm_fuse_pass");
passes_.push_back("attention_lstm_fuse_pass"); passes_.push_back("attention_lstm_fuse_pass");
passes_.push_back("seqconv_eltadd_relu_fuse_pass"); passes_.push_back("seqconv_eltadd_relu_fuse_pass");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册