diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index 6995412d055c6b83b3064ceee1410590e481caab..b650821a3d507137aa8709fba73b00cdaba8a09b 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -146,6 +146,14 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, float shift, std::string shift_attr_name) const { auto inputs = op->inputs; + auto var_names = op->Op()->Inputs().at(input_name); + std::vector<std::string> unique_var_names; + for (unsigned i = 0; i < var_names.size(); i++) + if (std::find(unique_var_names.begin(), + unique_var_names.end(), + var_names[i]) == unique_var_names.end()) + unique_var_names.push_back(var_names[i]); + auto output = op->outputs[0]; PADDLE_ENFORCE_GE(inputs.size(), 1, platform::errors::InvalidArgument( @@ -163,7 +171,6 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, // create a quantize op desc prototype OpDesc q_desc; q_desc.SetType("quantize"); - std::vector<Node*> quantize_out_nodes(inputs.size()); std::vector<std::string> quantize_out_node_names(inputs.size()); @@ -171,25 +178,52 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, unsigned max = are_inputs_unsigned ? 
U8_MAX : S8_MAX; float scale = scale_out * max; - for (size_t i = 0; i < inputs.size(); i++) { - // Create quantize output variable + for (size_t var_id = 0; var_id < unique_var_names.size(); var_id++) { + auto index = -1; + for (size_t it = 0; it < inputs.size(); it++) { + if (inputs[it]->Name() == unique_var_names[var_id]) index = it; + } + + if (index == -1) { + PADDLE_ENFORCE_NE(index, + -1, + platform::errors::InvalidArgument( + "Var(%s) isn't the input of the %s operator.", + unique_var_names[var_id], + op->Op()->Type())); + } + + auto* input = inputs.at(index); + VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out")); - quantize_out_nodes[i] = g->CreateVarNode(&quantize_out_desc); - quantize_out_node_names[i] = quantize_out_nodes[i]->Name(); + quantize_out_nodes[var_id] = g->CreateVarNode(&quantize_out_desc); + quantize_out_node_names[var_id] = quantize_out_nodes[var_id]->Name(); q_desc.SetAttr("Scale", scale); q_desc.SetAttr("Shift", shift); - q_desc.SetInput("Input", std::vector<std::string>({inputs[i]->Name()})); - q_desc.SetOutput("Output", - std::vector<std::string>({quantize_out_node_names[i]})); + q_desc.SetInput("Input", std::vector<std::string>({input->Name()})); + q_desc.SetOutput( + "Output", std::vector<std::string>({quantize_out_node_names[var_id]})); q_desc.SetAttr("is_negative_input", !are_inputs_unsigned); auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied. // link quantize op - UnlinkNodes(inputs[i], op); - IR_NODE_LINK_TO(inputs[i], quantize_op); - IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]); - IR_NODE_LINK_TO(quantize_out_nodes[i], op); + UnlinkNodes(input, op); + IR_NODE_LINK_TO(input, quantize_op); + IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[var_id]); + IR_NODE_LINK_TO(quantize_out_nodes[var_id], op); + } + + // If any inputs were duplicated, they now have to be entered in the correct + // order. 
+ for (size_t i = unique_var_names.size(); i < var_names.size(); i++) { + auto index = std::find( + unique_var_names.begin(), unique_var_names.end(), var_names[i]); + if (index != unique_var_names.end()) { + auto id = std::distance(unique_var_names.begin(), index); + quantize_out_node_names[i] = quantize_out_nodes[id]->Name(); + IR_NODE_LINK_TO(quantize_out_nodes[id], op); + } } // update op's input @@ -252,6 +286,8 @@ void CPUQuantizePass::DequantizeOutputs(Graph* g, bool is_unsigned, std::string scale_attr_name) const { auto outputs = op->outputs; + auto var_names = op->Op()->Outputs().at(output_name); + PADDLE_ENFORCE_GE(outputs.size(), 1, platform::errors::InvalidArgument( @@ -259,37 +295,53 @@ void CPUQuantizePass::DequantizeOutputs(Graph* g, op->Name(), outputs.size())); - std::vector<std::string> quantize_in_node_names(outputs.size()); + std::vector<std::string> dequantize_in_node_names(outputs.size()); + std::vector<Node*> dequantize_in_nodes(outputs.size()); unsigned max = is_unsigned ? U8_MAX : S8_MAX; float scale = scale_to_one * max; - for (size_t i = 0; i < outputs.size(); i++) { + for (size_t var_id = 0; var_id < var_names.size(); var_id++) { + auto index = -1; + for (size_t it = 0; it < outputs.size(); it++) { + if (outputs[it]->Name() == var_names[var_id]) index = it; + } + + if (index == -1) { + PADDLE_ENFORCE_NE(index, + -1, + platform::errors::InvalidArgument( + "Var(%s) isn't the output of the %s operator.", + var_names[var_id], + op->Op()->Type())); + } + + auto* output = outputs.at(index); + // Create dequantize input variable VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in")); - Node* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc); - quantize_in_node_names[i] = dequantize_in_node->Name(); + dequantize_in_nodes[var_id] = g->CreateVarNode(&dequantize_in_desc); + dequantize_in_node_names[var_id] = dequantize_in_nodes[var_id]->Name(); // create a dequantize op node for output. 
OpDesc deq_desc; deq_desc.SetType("dequantize"); - deq_desc.SetInput("Input", - std::vector<std::string>({quantize_in_node_names[i]})); - deq_desc.SetOutput("Output", - std::vector<std::string>({outputs[i]->Name()})); + deq_desc.SetInput( + "Input", std::vector<std::string>({dequantize_in_node_names[var_id]})); + deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()})); deq_desc.SetAttr("Scale", scale); deq_desc.SetAttr("is_negative_input", !is_unsigned); auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied. // link dequantize op - UnlinkNodes(op, outputs[i]); - IR_NODE_LINK_TO(op, dequantize_in_node); - IR_NODE_LINK_TO(dequantize_in_node, dequantize_op); - IR_NODE_LINK_TO(dequantize_op, outputs[i]); + UnlinkNodes(op, output); + IR_NODE_LINK_TO(op, dequantize_in_nodes[var_id]); + IR_NODE_LINK_TO(dequantize_in_nodes[var_id], dequantize_op); + IR_NODE_LINK_TO(dequantize_op, output); } // update op's output - op->Op()->SetOutput(output_name, quantize_in_node_names); + op->Op()->SetOutput(output_name, dequantize_in_node_names); if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale); } diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc index 1be8a6ca44e0e3f7c0da729025aadf78c118c6a6..55de1efed7943e172a3bb3d54192d4a3a30ac472 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc @@ -881,6 +881,45 @@ TEST(CpuQuantizePass, multi_gru_3) { MainTestMultiGru(layers); } +static const std::initializer_list<std::string> + variable_names_multi_inputs_outputs = {"a", "b", "c1", "c2", "d", "e"}; + +// a->Pool->b +// b->Split->c1, c2 +// (c1, c2, c1, c2)->Concat->d +// d->Pool->e +ProgramDesc BuildProgramDescMulti() { + ProgramDesc prog; + for (auto& v : variable_names_multi_inputs_outputs) { + prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32); + } + + SetOp(&prog, "pool2d", "Pool", {"a"}, {"b"}, true, 
"float32"); + SetOp(&prog, "split", "Split", {"b"}, {"c1", "c2"}, true, "int8"); + SetOp( + &prog, "concat", "Concat", {"c1", "c2", "c1", "c2"}, {"d"}, true, "int8"); + SetOp(&prog, "pool2d", "Pool2", {"d"}, {"e"}, true, "float32"); + + return prog; +} + +TEST(CpuQuantizePass, multi_inputs_outputs_ops) { + // a->QUANT1->Split + // b1->DEQUANT->OUT->QUANT + // b2->DEQUANT->OUT->QUANT + // (b1, b2, b1, b2)->Concat->c->DEQUANT->Pool->d + int added_nodes = 6 * 2; + std::unordered_map<std::string, int> expected_operators = {{"pool2d", 2}, + {"concat", 1}, + {"split", 1}, + {"quantize", 3}, + {"dequantize", 3}}; + MainTest(BuildProgramDescMulti(), + variable_names_multi_inputs_outputs, + expected_operators, + added_nodes); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index b0ccbb8aa9d26cd2985c6648f8c15c3a1a0d2e94..69cf01278b33a63dade0922783a6411203d9c06d 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -158,6 +158,11 @@ void CPUQuantizeSquashPass::DequantQuantSquash( PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Scale")); float dequant_shift = dequant_op->Op()->GetAttrIfExists<float>("Shift"); float quant_shift = quant_op->Op()->GetAttrIfExists<float>("Shift"); + if (quant_op->Op()->GetAttrIfExists<bool>("is_negative_input") != + dequant_op->Op()->GetAttrIfExists<bool>("is_negative_input")) { + return; + } + PADDLE_ENFORCE_NE( nodes_keep_counter->find(dequant_out), nodes_keep_counter->end(), @@ -169,14 +174,13 @@ void CPUQuantizeSquashPass::DequantQuantSquash( if (dequant_scale == quant_scale && dequant_shift == quant_shift) { // squash dequantize-quantize to nothing auto quant_out_var_name = quant_out->Name(); - auto next_op_inputs = next_op_desc->InputNames(); - for (const auto& name : next_op_inputs) { - auto input_names = next_op_desc->Input(name); + for 
(auto input_name : next_op_desc->InputNames()) { + auto& input_names = next_op_desc->MutableInputs()->at(input_name); std::replace(input_names.begin(), input_names.end(), quant_out_var_name, dequant_in->Name()); - next_op_desc->SetInput(name, input_names); + next_op_desc->SetInput(input_name, input_names); } if (keep_dequant) @@ -413,12 +417,11 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { // update the next operator input, // by replacing quant_out with first_quant_out - auto last_op_names = last_op->Op()->Input(last_op_input_name); - last_op_names.erase( - std::remove( - last_op_names.begin(), last_op_names.end(), quant_out->Name()), - last_op_names.end()); - last_op_names.push_back(first_quant_out->Name()); + auto last_op_names = last_op->Op()->Inputs().at(last_op_input_name); + std::replace(last_op_names.begin(), + last_op_names.end(), + quant_out->Name(), + first_quant_out->Name()); last_op_op->SetInput(last_op_input_name, std::vector<std::string>(last_op_names));