未验证 提交 559b9754 编写于 作者: Y yeliang2258 提交者: GitHub

Fix ComputePropagateScalesMkldnnPass of MKLDNN (#47574) (#47639)

* add constant_folding_pass pass for mkldnn int8

* update UpdateScaleOpInOutScales
上级 75088bbf
@@ -336,27 +336,45 @@ void ComputePropagateScalesMkldnnPass::ComputeWeightScales(
ComputeLstmWeightScales(graph, scope, "WeightX", "WeightH", var_quant_scales); ComputeLstmWeightScales(graph, scope, "WeightX", "WeightH", var_quant_scales);
} }
void ComputePropagateScalesMkldnnPass::UpdateScaleOpInScale( void ComputePropagateScalesMkldnnPass::UpdateScaleOpInOutScales(
Node* op_node, Node* op_node,
const std::string& input_name, const std::string& input_name,
const std::string& output_name, const std::string& output_name,
StringPairMap* var_quant_scales) const { StringPairMap* var_quant_scales) const {
auto iter = var_quant_scales->find(output_name); auto out_iter = var_quant_scales->find(output_name);
if (iter != var_quant_scales->end()) { auto input_iter = var_quant_scales->find(input_name);
// All the input and output have scales
if (out_iter != var_quant_scales->end() &&
input_iter != var_quant_scales->end()) {
return;
}
const auto scale = PADDLE_GET_CONST(float, op_node->Op()->GetAttr("scale"));
if (std::abs(scale) < 1e-6 && out_iter != var_quant_scales->end()) {
return;
}
std::string name = input_name;
auto iter = out_iter;
if (input_iter != var_quant_scales->end()) {
iter = input_iter;
name = output_name;
}
phi::DenseTensor tmp_tensor;
auto pair = iter->second; auto pair = iter->second;
const auto tensor = pair.second; const auto tensor = pair.second;
const auto scale = PADDLE_GET_CONST(float, op_node->Op()->GetAttr("scale"));
Tensor tmp_tensor;
tmp_tensor.Resize(tensor.dims()); tmp_tensor.Resize(tensor.dims());
auto* data = tmp_tensor.mutable_data<float>(platform::CPUPlace()); auto* data = tmp_tensor.mutable_data<float>(platform::CPUPlace());
auto* src_data = tensor.data<float>();
for (int i = 0; i < tensor.numel(); i++) { for (int i = 0; i < tensor.numel(); i++) {
data[i] = data[i] * scale; if (out_iter != var_quant_scales->end()) {
data[i] = src_data[i] / scale;
} else {
data[i] = src_data[i] * scale;
} }
auto new_pair = std::make_pair(pair.first, tmp_tensor);
var_quant_scales->insert(std::make_pair(input_name, new_pair));
} }
auto new_pair = std::make_pair(pair.first, tmp_tensor);
var_quant_scales->insert(std::make_pair(name, new_pair));
} }
std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales( std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
...@@ -403,10 +421,12 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales( ...@@ -403,10 +421,12 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
} }
} else if (op_name == "scale") { } else if (op_name == "scale") {
const std::string output_name = op_node->Op()->Output("Out")[0]; const std::string output_name = op_node->Op()->Output("Out")[0];
auto out_iter = var_quant_scales->find(output_name);
if (out_iter != var_quant_scales->end()) {
const std::string input_name = op_node->Op()->Input("X")[0]; const std::string input_name = op_node->Op()->Input("X")[0];
UpdateScaleOpInScale( auto out_iter = var_quant_scales->find(output_name);
auto input_iter = var_quant_scales->find(input_name);
if (out_iter != var_quant_scales->end() ||
input_iter != var_quant_scales->end()) {
UpdateScaleOpInOutScales(
op_node, input_name, output_name, var_quant_scales); op_node, input_name, output_name, var_quant_scales);
} }
} }
......
...@@ -79,7 +79,7 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase { ...@@ -79,7 +79,7 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase {
void UpdateReluOutputScales(ir::Graph* graph, void UpdateReluOutputScales(ir::Graph* graph,
StringPairMap* var_quant_scales) const; StringPairMap* var_quant_scales) const;
void UpdateScaleOpInScale(Node* op_node, void UpdateScaleOpInOutScales(Node* op_node,
const std::string& input_name, const std::string& input_name,
const std::string& output_name, const std::string& output_name,
StringPairMap* var_quant_scales) const; StringPairMap* var_quant_scales) const;
......
...@@ -376,6 +376,7 @@ void CpuPassStrategy::EnableMkldnnInt8() { ...@@ -376,6 +376,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("quant_dequant_mkldnn_pass"); passes_.push_back("quant_dequant_mkldnn_pass");
passes_.push_back("mkldnn_placement_pass"); passes_.push_back("mkldnn_placement_pass");
passes_.push_back("simplify_with_basic_ops_pass"); passes_.push_back("simplify_with_basic_ops_pass");
passes_.push_back("constant_folding_pass");
passes_.push_back("layer_norm_fuse_pass"); passes_.push_back("layer_norm_fuse_pass");
passes_.push_back("attention_lstm_fuse_pass"); passes_.push_back("attention_lstm_fuse_pass");
passes_.push_back("seqconv_eltadd_relu_fuse_pass"); passes_.push_back("seqconv_eltadd_relu_fuse_pass");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册