diff --git a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc index df19bc9ade8d57b840306e70d877e3abc86d8a1f..477709466a71c76fdf126ec631471f159d773e3a 100644 --- a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc @@ -19,7 +19,6 @@ #include #include "paddle/fluid/framework/ir/graph_helper.h" -#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h" #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { @@ -394,8 +393,13 @@ std::unordered_set ComputePropagateScalesMkldnnPass::UpdateScales( auto out_iter = var_quant_scales->find(op_node->Op()->Output("Out")[0]); if (out_iter != var_quant_scales->end()) { std::vector input_names = op_node->Op()->Input("X"); - for (auto input_name : input_names) - (*var_quant_scales)[input_name] = out_iter->second; + for (auto input_name : input_names) { + auto concat_in_iter = var_quant_scales->find(input_name); + if (concat_in_iter == var_quant_scales->end()) + (*var_quant_scales)[input_name] = out_iter->second; + else + (*var_quant_scales)[input_name].second = out_iter->second.second; + } } } else if (op_name == "scale") { const std::string output_name = op_node->Op()->Output("Out")[0]; @@ -409,6 +413,40 @@ std::unordered_set ComputePropagateScalesMkldnnPass::UpdateScales( } return waiting_for_scale; } +void ComputePropagateScalesMkldnnPass::UpdateReluOutputScales( + ir::Graph* graph, StringPairMap* var_quant_scales) const { + for (auto* op_node : + ir::TopologyVarientSort(*graph, static_cast(0))) { + if (!op_node->IsOp()) continue; + auto op = op_node->Op(); + bool is_unsigned = false; + std::string output_name = "Out"; + std::string act_name; + if (op->Type() == "relu") { + is_unsigned = true; + } else { + if (op->Type() == "conv2d") { + act_name = "fuse_activation"; + output_name = "Output"; + } else if (op->Type() == "fc") { + act_name = "activation_type"; + } + if (!act_name.empty()) { + auto act = op->GetAttrIfExists(act_name); + if (act == "relu" || act == "relu6") { + is_unsigned = true; + } + } + } + if (is_unsigned) { + std::string output_var_name = op->Output(output_name)[0]; + auto out_iter = var_quant_scales->find(output_var_name); + if (out_iter != var_quant_scales->end()) { + (*var_quant_scales)[output_var_name].first = true; + } + } + } +} void ComputePropagateScalesMkldnnPass::PropagateScales( ir::Graph* graph, @@ -427,21 +465,6 @@ void ComputePropagateScalesMkldnnPass::PropagateScales( } } -void ComputePropagateScalesMkldnnPass::ConvertStringPairMap( - const StringPairMap& var_quant_scales, - std::unordered_map>* info_map) const { - for (auto iter = var_quant_scales.begin(); iter != var_quant_scales.end(); - iter++) { - auto* data = iter->second.second.data(); - std::vector data_v; - for (int i = 0; i < iter->second.second.numel(); i++) { - data_v.push_back(data[i]); - } - - info_map->insert(std::make_pair(iter->first, data_v)); - } -} - void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Convert paddle model to mkldnn quantized model."; const std::string pattern_name = "compute_propagate_scales_mkldnn_pass"; @@ -461,13 +484,13 @@ void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const { auto* scope = param_scope(); GetQuantInfo(graph, &var_quant_scales); ComputeWeightScales(graph, scope, &var_quant_scales); + UpdateReluOutputScales(graph, &var_quant_scales); PropagateScales(graph, &var_quant_scales, scale_immutable_ops); // save var_quant_scales in the first op's attr // for cpu_quantize_pass - std::unordered_map> info_map; - ConvertStringPairMap(var_quant_scales, &info_map); - SaveInfoInTheFirstOp(graph, "has_quant_info", "var_quant_scales", info_map); + SaveInfoInTheFirstOp( + graph, "has_quant_info", "var_quant_scales", var_quant_scales); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.h b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.h index ecc3ad16a54e6a7cb0aec7d8eee2653352976448..bae810746ae2df1c92198080ff5367a64140063c 100644 --- a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.h @@ -17,14 +17,12 @@ #include #include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h" namespace paddle { namespace framework { namespace ir { -using StringPairMap = - std::unordered_map>; - class ComputePropagateScalesMkldnnPass : public FusePassBase { public: ComputePropagateScalesMkldnnPass() = default; @@ -78,6 +76,9 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase { Scope* scope, StringPairMap* var_quant_scales) const; + void UpdateReluOutputScales(ir::Graph* graph, + StringPairMap* var_quant_scales) const; + void UpdateScaleOpInScale(Node* op_node, const std::string& input_name, const std::string& output_name, @@ -92,10 +93,6 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase { ir::Graph* graph, StringPairMap* var_quant_scales, const std::unordered_set& scale_immutable_ops) const; - - void ConvertStringPairMap( - const StringPairMap& var_quant_scales, - std::unordered_map>* info_map) const; }; } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass_tester.cc index 03c01507ca27d2504332559c22f4a4f1954ccfb6..39ecfd2c0e79a543053072bd032dbd253f9fd14e 100644 --- a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass_tester.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include "paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.h" #include "paddle/fluid/framework/naive_executor.h" @@ -91,11 +92,16 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test { graph, scope, wx_name, wh_name, var_quant_scales); } + void UpdateReluOutputScales(ir::Graph* graph, + StringPairMap* var_quant_scales) const { + pass->UpdateReluOutputScales(graph, var_quant_scales); + } + void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, const std::string& var_name) { auto x = scope->Var(var_name); - auto tensor = x->GetMutable(); + auto tensor = x->GetMutable(); auto tensor_size = 1; if (var_name == "filter") { tensor_size = positive_and_negative_values.size(); @@ -124,7 +130,6 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test { } void ComputeRnnWeightScalesTest(const std::string& type, - const std::initializer_list& ops, const framework::ProgramDesc& prog, std::vector scales) { ir::Graph* graph(new ir::Graph(prog)); @@ -140,7 +145,7 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test { StringPairMap var_quant_scales; auto* wx_var = scope.FindVar(wx_var_names); - auto* wx_tensor = wx_var->GetMutable(); + auto* wx_tensor = wx_var->GetMutable(); wx_tensor->Resize(phi::make_dim(wx.size(), wx[0].size())); for (size_t i = 0; i < wx.size(); i++) std::copy(begin(wx[i]), @@ -149,7 +154,7 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test { i * wx[0].size()); auto* wh_var = scope.FindVar(wh_var_names); - auto* wh_tensor = wh_var->GetMutable(); + auto* wh_tensor = wh_var->GetMutable(); wh_tensor->Resize(phi::make_dim(wh.size(), wh[0].size())); for (size_t i = 0; i < wh.size(); i++) std::copy(begin(wh[i]), @@ -174,6 +179,24 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test { } } + void UpdateReluOutputScaleTest( + const framework::ProgramDesc& prog, + StringPairMap* var_quant_scales, + const std::initializer_list& variable_names) { + ir::Graph* graph(new ir::Graph(prog)); + Scope scope; + + PrepareGraph(graph, prog, &scope, conv_variable_names); + + UpdateReluOutputScales(graph, var_quant_scales); + + for (auto& var_name : variable_names) { + auto iter = var_quant_scales->find(var_name); + ASSERT_NE(iter, var_quant_scales->end()); + ASSERT_EQ((*var_quant_scales)[var_name].first, true); + } + } + private: std::unique_ptr pass; }; @@ -182,11 +205,15 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, const std::vector& inputs, - const std::vector& outputs) { + const std::vector& outputs, + const std::unordered_map& attrs = {}) { auto* op = prog->MutableBlock(0)->AppendOp(); op->SetType(type); op->SetAttr("use_mkldnn", true); op->SetAttr("name", name); + if (!attrs.empty()) + for (auto& attr : attrs) op->SetAttr(attr.first, attr.second); + if (type == "conv2d") { op->SetInput("Input", {inputs[0]}); if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]}); @@ -211,6 +238,23 @@ ProgramDesc BuildConv2dProgramDesc() { return prog; } +ProgramDesc BuildConv2dReluProgramDesc() { + ProgramDesc prog; + for (auto& v : conv_variable_names) { + prog.MutableBlock(0)->Var(v); + } + std::unordered_map attrs = { + {"fuse_activation", "relu"}}; + SetOp(&prog, + "conv2d", + "Conv2d", + {"conv_in", "filter", "bias"}, + {"conv_out"}, + attrs); + + return prog; +} + ProgramDesc BuildFusionGruProgramDesc() { ProgramDesc prog; for (auto& v : rnn_variable_names) { @@ -262,7 +306,7 @@ TEST_F(ComputePropagateScalesMkldnnPassTest, compute_var_scales) { StringPairMap var_quant_scales; auto* var = scope.FindVar(weight_var_name); - auto* weight_tensor = var->GetMutable(); + auto* weight_tensor = var->GetMutable(); weight_tensor->Resize(phi::make_dim(1, values.size())); std::copy(begin(values), end(values), @@ -283,15 +327,24 @@ TEST_F(ComputePropagateScalesMkldnnPassTest, compute_var_scales) { } TEST_F(ComputePropagateScalesMkldnnPassTest, compute_gru_weight_scales) { - ComputeRnnWeightScalesTest("gru", - {"fusion_gru", "multi_gru"}, - BuildFusionGruProgramDesc(), - gru_scales); + ComputeRnnWeightScalesTest("gru", BuildFusionGruProgramDesc(), gru_scales); } TEST_F(ComputePropagateScalesMkldnnPassTest, compute_lstm_weight_scales) { - ComputeRnnWeightScalesTest( - "lstm", {"fusion_lstm"}, BuildFusionLstmProgramDesc(), lstm_scales); + ComputeRnnWeightScalesTest("lstm", BuildFusionLstmProgramDesc(), lstm_scales); +} + +TEST_F(ComputePropagateScalesMkldnnPassTest, update_relu_output_scales) { + StringPairMap var_quant_scales; + for (auto& var_name : conv_variable_names) { + phi::DenseTensor tensor; + auto* data = tensor.mutable_data({1}, platform::CPUPlace()); + data[0] = 10; + auto pair = std::make_pair(false, tensor); + var_quant_scales.insert(std::make_pair(var_name, pair)); + } + UpdateReluOutputScaleTest( + BuildConv2dReluProgramDesc(), &var_quant_scales, {"conv_out"}); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index 5ec22e2e88a1e1079c9880ca12303032afafea49..9e7ba25755c4c673e51a73c66cad84743473c9a4 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -229,6 +229,7 @@ void CPUQuantizePass::DequantizeOutput(Graph* g, std::vector({dequantize_in_node->Name()})); deq_desc.SetOutput("Output", std::vector({output->Name()})); deq_desc.SetAttr("Scale", scale); + deq_desc.SetAttr("is_negative_input", !is_unsigned); auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied. // update op's output @@ -332,20 +333,8 @@ bool CPUQuantizePass::IsOpQuantized(const Node* node) const { } void CPUQuantizePass::GetQuantInfo(Graph* graph) const { - std::unordered_map> info_map{}; - GetInfoFromTheFirstOp(graph, "has_quant_info", "var_quant_scales", &info_map); - - for (auto iter = info_map.begin(); iter != info_map.end(); iter++) { - LoDTensor tensor; - const int size = static_cast(iter->second.size()); - auto* data = tensor.mutable_data({size}, platform::CPUPlace()); - for (int i = 0; i < size; i++) { - data[i] = static_cast(iter->second[i]); - } - - auto pair = std::make_pair(false, tensor); - var_quant_scales_->insert(std::make_pair(iter->first, pair)); - } + GetInfoFromTheFirstOp( + graph, "has_quant_info", "var_quant_scales", var_quant_scales_); } void CPUQuantizePass::QuantizeConv(Graph* graph, @@ -593,6 +582,20 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const { return; } + bool are_all_inputs_unsigned{true}; + // if all inputs were unsigned, then the output was set to unsigned + // during the scale calculation step + auto inputs = concat_op->inputs; + for (size_t i = 0; i < inputs.size(); i++) { + if (AreScalesPresentForVarNames({inputs[i]->Name()})) { + auto scale_data = GetScaleDataByName(inputs[i]->Name()); + if (scale_data.first == false) { + are_all_inputs_unsigned = false; + break; + } + } + } + GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern); if (!AreScalesPresentForNodes({concat_out})) { @@ -601,17 +604,12 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const { return; } - // if all inputs were unsigned, then the output was set to unsigned - // during the scale calculation step - bool are_all_inputs_unsigned{false}; - auto output_scale = - GetScaleValueForNode(concat_out, &are_all_inputs_unsigned); + auto output_scale = GetScaleValueForNode(concat_out); QuantizeInputs(g, concat_op, "X", are_all_inputs_unsigned); DequantizeOutput( g, concat_op, concat_out, "Out", output_scale, are_all_inputs_unsigned); - ++quantize_concat_count; }; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index 933d60b0a2739432fbda090535d89fb1cdafb220..e0a64b2036bb7a048d706e77eac0acefc6964fdf 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -105,51 +105,24 @@ void CPUQuantizeSquashPass::FindNodesToKeep( AddStatis(found_count); } -bool CPUQuantizeSquashPass::IsDequantizeInputUint8( - const Node* dequant_in) const { - PADDLE_ENFORCE_EQ( - dequant_in->inputs.size(), - 1, - platform::errors::InvalidArgument( - "Dequantize (id: %f) should have only one input.", dequant_in->id())); - if (dequant_in->inputs[0]->IsOp()) { - auto prev_op = dequant_in->inputs[0]->Op(); - std::string act_name; - if (prev_op->Type() == "relu") { - return true; - } else { - if (prev_op->Type() == "conv2d") { - act_name = "fuse_activation"; - } else if (prev_op->Type() == "fc") { - act_name = "activation_type"; - } - if (!act_name.empty()) { - auto act = prev_op->GetAttrIfExists(act_name); - if (act == "relu" || act == "relu6") { - return true; - } - } - } - } - return false; -} - bool CPUQuantizeSquashPass::IsDequantizeQuantizeIncompatible( - Node* quant_op, Node* dequant_in, Node* next_op) const { - bool is_concat_signed = + Node* quant_op, Node* dequant_op, Node* next_op) const { + bool is_next_op_signed = quant_op->Op()->GetAttrIfExists("is_negative_input"); - bool is_input_unsigned = IsDequantizeInputUint8(dequant_in); + bool is_input_signed = + dequant_op->Op()->GetAttrIfExists("is_negative_input"); + /* TODO(sfraczek): remove elementwise from this condition when BinaryMKLDNN kernel will support two different input data types */ bool is_next_op_concat_or_elementwise = next_op->Op()->Type() == "concat" || next_op->Op()->Type().find("elementwise") == 0; - if (is_next_op_concat_or_elementwise && is_concat_signed && - is_input_unsigned) { + if (is_next_op_concat_or_elementwise && + (is_next_op_signed ^ is_input_signed)) { VLOG(4) << "Do not squash dequant-quant, because " << "next_op is: " << next_op->Op()->Type() - << ", is_concat_signed: " << is_concat_signed - << ", is_input_unsigned: " << is_input_unsigned << "."; + << ", is_next_op_signed: " << is_next_op_signed + << ", is_input_signed: " << is_input_signed << "."; return true; } return false; @@ -174,7 +147,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash( GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern); GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern); - if (IsDequantizeQuantizeIncompatible(quant_op, dequant_in, next_op)) { + if (IsDequantizeQuantizeIncompatible(quant_op, dequant_op, next_op)) { return; } diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h index 5207cc519c698090ba4e6c3bc37d2ede79b5a852..3aed54609d4512f94f716adb11ce1e4ebf55994f 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h @@ -43,11 +43,6 @@ class CPUQuantizeSquashPass : public FusePassBase { Graph* graph, std::unordered_map* nodes_keep_counter) const; - /* - * Check if input to dequantize is uint8 - */ - bool IsDequantizeInputUint8(const Node* dequant_in) const; - /* * Don't squash unsigned dequantize with signed quantize. * This is important for concat and elementwise ops. @@ -55,7 +50,7 @@ class CPUQuantizeSquashPass : public FusePassBase { * elementwise assumes first input type. */ bool IsDequantizeQuantizeIncompatible(Node* quant_op, - Node* dequant_in, + Node* dequant_op, Node* next_op) const; /* diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc index 655cc95bf28a0506d24f6924a3bc5bd5ee2061a7..cd71ff153d60101866d93a144016c592478ddb0b 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc @@ -68,15 +68,11 @@ void SetOp(ProgramDesc* prog, op->SetAttr("padding_algorithm", std::string("EXPLICIT")); op->SetAttr("data_format", std::string("NCHW")); op->SetAttr("force_fp32_output", false); - } else if (type == "quantize") { + } else if (type == "quantize" || type == "dequantize") { op->SetInput("Input", {inputs[0]}); op->SetOutput("Output", {outputs[0]}); op->SetAttr("Scale", scale[0]); op->SetAttr("is_negative_input", is_negative_input); - } else if (type == "dequantize") { - op->SetInput("Input", {inputs[0]}); - op->SetOutput("Output", {outputs[0]}); - op->SetAttr("Scale", scale[0]); } else if (type == "requantize") { op->SetInput("Input", {inputs[0]}); op->SetOutput("Output", {outputs[0]}); @@ -303,31 +299,22 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, return prog; } -/* a->relu->b->Dequant->c(u8)->Quant->d-\ - * e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x - * i->relu->j->Dequant->k(u8)->Quant->l-/ +/* a->relu->b->Dequant(u8)->c->Quant(u8)->d-\ + * e->relu->f->Dequant(u8)->g->Quant(u8)->h--Concat1->i */ -ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) { +ProgramDesc BuildU8U8ConcatProgramDesc(float scale_out, float scale) { ProgramDesc prog; for (auto& v : variable_names) { prog.MutableBlock(0)->Var(v); } SetOp(&prog, "relu", "Relu1", {"a"}, {"b"}, true, {scale, scale_out}); SetOp(&prog, "relu", "Relu2", {"e"}, {"f"}, true, {scale, scale_out}); - SetOp(&prog, "relu", "Relu3", {"i"}, {"j"}, true, {scale, scale_out}); - - SetOp( - &prog, "dequantize", "Dequant1", {"b"}, {"c"}, true, {scale, scale_out}); - SetOp( - &prog, "dequantize", "Dequant2", {"f"}, {"g"}, true, {scale, scale_out}); - SetOp( - &prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, {scale, scale_out}); SetOp(&prog, - "quantize", - "Quant1", + "dequantize", + "Dequant1", + {"b"}, {"c"}, - {"d"}, true, {scale, scale_out}, 0.0f, @@ -336,10 +323,23 @@ ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) { 1, false); // is_negative_input = false SetOp(&prog, - "quantize", - "Quant2", + "dequantize", + "Dequant2", + {"f"}, {"g"}, - {"h"}, + true, + {scale, scale_out}, + 0.0f, + "float32", + false, + 1, + false); // is_negative_input = false + + SetOp(&prog, + "quantize", + "Quant1", + {"c"}, + {"d"}, true, {scale, scale_out}, 0.0f, @@ -349,9 +349,9 @@ ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) { false); // is_negative_input = false SetOp(&prog, "quantize", - "Quant3", - {"k"}, - {"l"}, + "Quant2", + {"g"}, + {"h"}, true, {scale, scale_out}, 0.0f, @@ -360,27 +360,47 @@ ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) { 1, false); // is_negative_input = false - SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true); + SetOp(&prog, "concat", "Concat1", {"d", "h"}, {"i"}, true); return prog; } -/* a->relu->b->Dequant->c(u8)->Quant->d-\ - * e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x - * i->pool2d->j->Dequant->k(s8)->Quant->l-/ +/* a->relu->b->Dequant(u8)->c->Quant(s8)->d-\ + * e->relu->f->Dequant(u8)->g->Quant(s8)->h--Concat1->x + * i->pool2d->j->Dequant(s8)->k->Quant(s8)->l-/ */ ProgramDesc BuildU8U8S8ConcatProgramDesc(float scale_out, float scale) { ProgramDesc prog; for (auto& v : variable_names) { prog.MutableBlock(0)->Var(v); } - SetOp(&prog, "relu", "Pool2d1", {"a"}, {"b"}, true, {scale, scale_out}); - SetOp(&prog, "relu", "Relu1", {"e"}, {"f"}, true, {scale, scale_out}); + SetOp(&prog, "relu", "Relu1", {"a"}, {"b"}, true, {scale, scale_out}); + SetOp(&prog, "relu", "Relu2", {"e"}, {"f"}, true, {scale, scale_out}); SetOp(&prog, "pool2d", "Pool2d2", {"i"}, {"j"}, true, {scale, scale_out}); - SetOp( - &prog, "dequantize", "Dequant1", {"b"}, {"c"}, true, {scale, scale_out}); - SetOp( - &prog, "dequantize", "Dequant2", {"f"}, {"g"}, true, {scale, scale_out}); + SetOp(&prog, + "dequantize", + "Dequant1", + {"b"}, + {"c"}, + true, + {scale, scale_out}, + 0.0f, + "float32", + false, + 1, + false); // is_negative_input = false + SetOp(&prog, + "dequantize", + "Dequant2", + {"f"}, + {"g"}, + true, + {scale, scale_out}, + 0.0f, + "float32", + false, + 1, + false); // is_negative_input = false SetOp( &prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, {scale, scale_out}); @@ -392,9 +412,9 @@ ProgramDesc BuildU8U8S8ConcatProgramDesc(float scale_out, float scale) { return prog; } -/* a->pool2d->b->Dequant->c(s8)->Quant->d-\ - * e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x - * i->pool2d->j->Dequant->k(s8)->Quant->l-/ +/* a->pool2d->b->Dequant(s8)->c->Quant(s8)->d-\ + * e->relu->f->Dequant(u8)->g->Quant(s8)->h--Concat1->x + * i->pool2d->j->Dequant(s8)->k->Quant(s8)->l-/ */ ProgramDesc BuildS8U8S8ConcatProgramDesc(float scale_out, float scale) { ProgramDesc prog; @@ -407,8 +427,18 @@ ProgramDesc BuildS8U8S8ConcatProgramDesc(float scale_out, float scale) { SetOp( &prog, "dequantize", "Dequant1", {"b"}, {"c"}, true, {scale, scale_out}); - SetOp( - &prog, "dequantize", "Dequant2", {"f"}, {"g"}, true, {scale, scale_out}); + SetOp(&prog, + "dequantize", + "Dequant2", + {"f"}, + {"g"}, + true, + {scale, scale_out}, + 0.0f, + "float32", + false, + 1, + false); // is_negative_input = false SetOp( &prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, {scale, scale_out}); @@ -1141,13 +1171,12 @@ TEST(CpuQuantizeSquashPass, squash_all_s8_input_to_concat1) { } TEST(CpuQuantizeSquashPass, squash_all_u8_input_to_concat2) { - // removed 3 x 4 (dequantize_op, dequantize_out, quantize, quantize_out) - auto remove_nodes = 12; + // removed 2 x 4 (dequantize_op, dequantize_out, quantize, quantize_out) + auto remove_nodes = 8; std::unordered_map expected_operators = { - {"concat", 1}, {"quantize", 0}, {"dequantize", 0}, {"relu", 3}}; - CheckNodesTest(BuildU8U8U8ConcatProgramDesc(1.2f, 1.2f), - expected_operators, - remove_nodes); + {"concat", 1}, {"quantize", 0}, {"dequantize", 0}, {"relu", 2}}; + CheckNodesTest( + BuildU8U8ConcatProgramDesc(1.2f, 1.2f), expected_operators, remove_nodes); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h b/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h index a714f236c461656585686ee61d3d471df5e89d97..6899a7202da9cc3734dc01d6ab7c24d30a6a5606 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h @@ -22,6 +22,9 @@ namespace paddle { namespace framework { namespace ir { +using StringPairMap = + std::unordered_map>; + static void SaveInfoInTheFirstOp( ir::Graph* graph, const std::string& flag, @@ -44,6 +47,31 @@ static void SaveInfoInTheFirstOp( } } +static void SaveInfoInTheFirstOp(ir::Graph* graph, + const std::string& flag, + const std::string& key_suffix, + const StringPairMap& info_map) { + VLOG(3) << "save variables in the first op's attr"; + + const std::string suffix = "_" + key_suffix + "_" + flag; + for (auto* op_node : + ir::TopologyVarientSort(*graph, static_cast(0))) { + if (!op_node->IsOp() || op_node->Op()->Type() == "feed" || + op_node->Op()->Type() == "fetch") + continue; + + op_node->Op()->SetAttr(flag, true); + for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) { + auto* data = iter->second.second.data(); + std::vector data_v(data, data + iter->second.second.numel()); + op_node->Op()->SetAttr(iter->first + suffix + "_unsigned", + iter->second.first); + op_node->Op()->SetAttr(iter->first + suffix, data_v); + } + break; + } +} + static void GetInfoFromTheFirstOp( ir::Graph* graph, const std::string& flag, @@ -77,6 +105,54 @@ static void GetInfoFromTheFirstOp( } } +static void GetInfoFromTheFirstOp(ir::Graph* graph, + const std::string& flag, + const std::string& key_suffix, + StringPairMap* info_map) { + VLOG(3) << "get variables from the first op's attr"; + const std::string unsigned_flag = "_unsigned"; + const std::string suffix = "_" + key_suffix + "_" + flag; + const std::string suffix_is_unsigned = suffix + unsigned_flag; + for (auto* op_node : + ir::TopologyVarientSort(*graph, static_cast(0))) { + if (!op_node->IsOp() || op_node->Op()->Type() == "feed" || + op_node->Op()->Type() == "fetch") + continue; + + auto* op_desc = op_node->Op(); + if (op_desc->GetAttrIfExists(flag)) { + op_desc->RemoveAttr(flag); + std::vector attr_names = op_desc->AttrNames(); + for (auto fake_name : attr_names) { + auto is_unsigned = false; + size_t pos = fake_name.find(suffix_is_unsigned); + + if (pos != std::string::npos) { + std::string unsigned_var_name = fake_name; + is_unsigned = + PADDLE_GET_CONST(bool, op_desc->GetAttr(unsigned_var_name)); + + std::string var_name = fake_name.substr(0, pos); + size_t unsigned_pos = fake_name.find(unsigned_flag); + std::string vector_name = + fake_name.erase(unsigned_pos, unsigned_flag.length()); + auto scales_vector = PADDLE_GET_CONST(std::vector, + op_desc->GetAttr(vector_name)); + phi::DenseTensor tensor; + const int size = static_cast(scales_vector.size()); + auto data = tensor.mutable_data({size}, platform::CPUPlace()); + std::copy(scales_vector.begin(), scales_vector.end(), data); + auto pair = std::make_pair(is_unsigned, tensor); + info_map->insert(std::make_pair(var_name, pair)); + op_desc->RemoveAttr(unsigned_var_name); + op_desc->RemoveAttr(vector_name); + } + } + break; + } + } +} + } // namespace ir } // namespace framework } // namespace paddle