Unverified commit 0ffba1c9 authored by joanna.wozna.intel, committed by GitHub

Correct multiple inputs and outputs (#48872)

Parent 428fb804
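
The core change in CPUQuantizePass::QuantizeInputs is to create one quantize op per unique input variable name and then map any duplicated names back to the quantize outputs created for their first occurrences. Below is a minimal, self-contained sketch of that name mapping (illustration only; the function name and the "_quantized" suffix are invented here, and the real pass works on graph nodes rather than plain strings):

#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

// Illustration only: map each (possibly duplicated) input name to a
// quantized name, creating a new name only for the first occurrence.
// Mirrors the two-pass structure used in QuantizeInputs.
std::vector<std::string> MapToQuantizedNames(
    const std::vector<std::string>& var_names) {
  std::vector<std::string> unique_var_names;
  for (const auto& name : var_names)
    if (std::find(unique_var_names.begin(), unique_var_names.end(), name) ==
        unique_var_names.end())
      unique_var_names.push_back(name);

  std::vector<std::string> quantized(var_names.size());
  // First pass: one new quantize output per unique input name.
  for (size_t i = 0; i < unique_var_names.size(); i++)
    quantized[i] = unique_var_names[i] + "_quantized";
  // Second pass: the remaining (duplicated) positions reuse the output
  // created for the matching unique name.
  for (size_t i = unique_var_names.size(); i < var_names.size(); i++) {
    auto it = std::find(
        unique_var_names.begin(), unique_var_names.end(), var_names[i]);
    quantized[i] = quantized[std::distance(unique_var_names.begin(), it)];
  }
  return quantized;
}
// e.g. {"c1", "c2", "c1", "c2"}
//   -> {"c1_quantized", "c2_quantized", "c1_quantized", "c2_quantized"}
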
@@ -146,6 +146,14 @@ void CPUQuantizePass::QuantizeInputs(Graph* g,
float shift,
std::string shift_attr_name) const {
auto inputs = op->inputs;
auto var_names = op->Op()->Inputs().at(input_name);
std::vector<std::string> unique_var_names;
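// var_names may repeat (e.g. the same variable passed to concat twice);
// keep only the first occurrence of each name, preserving order.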
for (unsigned i = 0; i < var_names.size(); i++)
if (std::find(unique_var_names.begin(),
unique_var_names.end(),
var_names[i]) == unique_var_names.end())
unique_var_names.push_back(var_names[i]);
auto output = op->outputs[0];
PADDLE_ENFORCE_GE(inputs.size(),
1,
@@ -163,7 +171,6 @@ void CPUQuantizePass::QuantizeInputs(Graph* g,
// create a quantize op desc prototype
OpDesc q_desc;
q_desc.SetType("quantize");
std::vector<Node*> quantize_out_nodes(inputs.size());
std::vector<std::string> quantize_out_node_names(inputs.size());
@@ -171,25 +178,52 @@ void CPUQuantizePass::QuantizeInputs(Graph* g,
unsigned max = are_inputs_unsigned ? U8_MAX : S8_MAX;
float scale = scale_out * max;
for (size_t i = 0; i < inputs.size(); i++) {
// Create a quantize output variable for each unique input
for (size_t var_id = 0; var_id < unique_var_names.size(); var_id++) {
auto index = -1;
for (size_t it = 0; it < inputs.size(); it++) {
if (inputs[it]->Name() == unique_var_names[var_id]) index = it;
}
if (index == -1) {
PADDLE_ENFORCE_NE(index,
-1,
platform::errors::InvalidArgument(
"Var(%s) isn't the input of the %s operator.",
unique_var_names[var_id],
op->Op()->Type()));
}
auto* input = inputs.at(index);
VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
quantize_out_nodes[i] = g->CreateVarNode(&quantize_out_desc);
quantize_out_node_names[i] = quantize_out_nodes[i]->Name();
quantize_out_nodes[var_id] = g->CreateVarNode(&quantize_out_desc);
quantize_out_node_names[var_id] = quantize_out_nodes[var_id]->Name();
q_desc.SetAttr("Scale", scale);
q_desc.SetAttr("Shift", shift);
q_desc.SetInput("Input", std::vector<std::string>({inputs[i]->Name()}));
q_desc.SetOutput("Output",
std::vector<std::string>({quantize_out_node_names[i]}));
q_desc.SetInput("Input", std::vector<std::string>({input->Name()}));
q_desc.SetOutput(
"Output", std::vector<std::string>({quantize_out_node_names[var_id]}));
q_desc.SetAttr("is_negative_input", !are_inputs_unsigned);
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
// link quantize op
UnlinkNodes(inputs[i], op);
IR_NODE_LINK_TO(inputs[i], quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]);
IR_NODE_LINK_TO(quantize_out_nodes[i], op);
UnlinkNodes(input, op);
IR_NODE_LINK_TO(input, quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[var_id]);
IR_NODE_LINK_TO(quantize_out_nodes[var_id], op);
}
// If any input names were duplicated, fill the remaining positions with the
// quantize outputs already created for their first occurrences and link them.
for (size_t i = unique_var_names.size(); i < var_names.size(); i++) {
auto index = std::find(
unique_var_names.begin(), unique_var_names.end(), var_names[i]);
if (index != unique_var_names.end()) {
auto id = std::distance(unique_var_names.begin(), index);
quantize_out_node_names[i] = quantize_out_nodes[id]->Name();
IR_NODE_LINK_TO(quantize_out_nodes[id], op);
}
}
// update op's input
@@ -252,6 +286,8 @@ void CPUQuantizePass::DequantizeOutputs(Graph* g,
bool is_unsigned,
std::string scale_attr_name) const {
auto outputs = op->outputs;
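// Names declared for this output slot; op->outputs holds the corresponding graph nodes.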
auto var_names = op->Op()->Outputs().at(output_name);
PADDLE_ENFORCE_GE(outputs.size(),
1,
platform::errors::InvalidArgument(
@@ -259,37 +295,53 @@ void CPUQuantizePass::DequantizeOutputs(Graph* g,
op->Name(),
outputs.size()));
std::vector<std::string> quantize_in_node_names(outputs.size());
std::vector<std::string> dequantize_in_node_names(outputs.size());
std::vector<Node*> dequantize_in_nodes(outputs.size());
unsigned max = is_unsigned ? U8_MAX : S8_MAX;
float scale = scale_to_one * max;
for (size_t i = 0; i < outputs.size(); i++) {
for (size_t var_id = 0; var_id < var_names.size(); var_id++) {
auto index = -1;
for (size_t it = 0; it < outputs.size(); it++) {
if (outputs[it]->Name() == var_names[var_id]) index = it;
}
if (index == -1) {
PADDLE_ENFORCE_NE(index,
-1,
platform::errors::InvalidArgument(
"Var(%s) isn't the input of the %s operator.",
var_names[var_id],
op->Op()->Type()));
}
auto* output = outputs.at(index);
// Create dequantize input variable
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
Node* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
quantize_in_node_names[i] = dequantize_in_node->Name();
dequantize_in_nodes[var_id] = g->CreateVarNode(&dequantize_in_desc);
dequantize_in_node_names[var_id] = dequantize_in_nodes[var_id]->Name();
// create a dequantize op node for output.
OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({quantize_in_node_names[i]}));
deq_desc.SetOutput("Output",
std::vector<std::string>({outputs[i]->Name()}));
deq_desc.SetInput(
"Input", std::vector<std::string>({dequantize_in_node_names[var_id]}));
deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()}));
deq_desc.SetAttr("Scale", scale);
deq_desc.SetAttr("is_negative_input", !is_unsigned);
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.
// link dequantize op
UnlinkNodes(op, outputs[i]);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, outputs[i]);
UnlinkNodes(op, output);
IR_NODE_LINK_TO(op, dequantize_in_nodes[var_id]);
IR_NODE_LINK_TO(dequantize_in_nodes[var_id], dequantize_op);
IR_NODE_LINK_TO(dequantize_op, output);
}
// update op's output
op->Op()->SetOutput(output_name, quantize_in_node_names);
op->Op()->SetOutput(output_name, dequantize_in_node_names);
if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
}
@@ -881,6 +881,45 @@ TEST(CpuQuantizePass, multi_gru_3) {
MainTestMultiGru(layers);
}
static const std::initializer_list<std::string>
variable_names_multi_inputs_outputs = {"a", "b", "c1", "c2", "d", "e"};
// a->Pool->b
// b->Split->c1, c2
// (c1, c2, c1, c2)->Concat->d
// d->Pool->e
ProgramDesc BuildProgramDescMulti() {
ProgramDesc prog;
for (auto& v : variable_names_multi_inputs_outputs) {
prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
}
SetOp(&prog, "pool2d", "Pool", {"a"}, {"b"}, true, "float32");
SetOp(&prog, "split", "Split", {"b"}, {"c1", "c2"}, true, "int8");
SetOp(
&prog, "concat", "Concat", {"c1", "c2", "c1", "c2"}, {"d"}, true, "int8");
SetOp(&prog, "pool2d", "Pool2", {"d"}, {"e"}, true, "float32");
return prog;
}
TEST(CpuQuantizePass, multi_inputs_outputs_ops) {
// Expected graph after the pass:
// a->Pool->b
// b->QUANT->Split->DEQUANT->(c1, c2)
// (c1, c2, c1, c2)->QUANT (one per unique input)->Concat->DEQUANT->d
// d->Pool2->e
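// 12 added nodes: 3 quantize + 3 dequantize op nodes, each with one new variable node.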
int added_nodes = 6 * 2;
std::unordered_map<std::string, int> expected_operators = {{"pool2d", 2},
{"concat", 1},
{"split", 1},
{"quantize", 3},
{"dequantize", 3}};
MainTest(BuildProgramDescMulti(),
variable_names_multi_inputs_outputs,
expected_operators,
added_nodes);
}
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -158,6 +158,11 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Scale"));
float dequant_shift = dequant_op->Op()->GetAttrIfExists<float>("Shift");
float quant_shift = quant_op->Op()->GetAttrIfExists<float>("Shift");
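// Do not squash when the pair's signedness differs: removing it would drop
// the required s8<->u8 conversion.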
if (quant_op->Op()->GetAttrIfExists<bool>("is_negative_input") !=
dequant_op->Op()->GetAttrIfExists<bool>("is_negative_input")) {
return;
}
PADDLE_ENFORCE_NE(
nodes_keep_counter->find(dequant_out),
nodes_keep_counter->end(),
@@ -169,14 +174,13 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
if (dequant_scale == quant_scale && dequant_shift == quant_shift) {
// squash dequantize-quantize to nothing
auto quant_out_var_name = quant_out->Name();
auto next_op_inputs = next_op_desc->InputNames();
for (const auto& name : next_op_inputs) {
auto input_names = next_op_desc->Input(name);
for (auto input_name : next_op_desc->InputNames()) {
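// MutableInputs() exposes the stored name lists directly, so the
// std::replace below updates the op's inputs in place.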
auto& input_names = next_op_desc->MutableInputs()->at(input_name);
std::replace(input_names.begin(),
input_names.end(),
quant_out_var_name,
dequant_in->Name());
next_op_desc->SetInput(name, input_names);
next_op_desc->SetInput(input_name, input_names);
}
if (keep_dequant)
@@ -413,12 +417,11 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
// update the next operator input,
// by replacing quant_out with first_quant_out
auto last_op_names = last_op->Op()->Input(last_op_input_name);
last_op_names.erase(
std::remove(
last_op_names.begin(), last_op_names.end(), quant_out->Name()),
last_op_names.end());
last_op_names.push_back(first_quant_out->Name());
auto last_op_names = last_op->Op()->Inputs().at(last_op_input_name);
std::replace(last_op_names.begin(),
last_op_names.end(),
quant_out->Name(),
first_quant_out->Name());
last_op_op->SetInput(last_op_input_name,
std::vector<std::string>(last_op_names));
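
The MultipleQuantizeSquash change swaps an erase-and-append rewrite for std::replace. The difference matters when the next operator lists the same input name more than once. A small standalone sketch of that difference (illustration only, with made-up names):

#include <algorithm>
#include <string>
#include <vector>

int main() {
  // An input slot that repeats a name, as concat can after quantization.
  std::vector<std::string> names = {"quant_out", "x", "quant_out"};

  // Previous approach: erase every occurrence, append the replacement once.
  std::vector<std::string> erased = names;
  erased.erase(std::remove(erased.begin(), erased.end(), "quant_out"),
               erased.end());
  erased.push_back("first_quant_out");
  // erased == {"x", "first_quant_out"}: position and multiplicity are lost.

  // New approach: replace each occurrence in place.
  std::vector<std::string> replaced = names;
  std::replace(replaced.begin(),
               replaced.end(),
               std::string("quant_out"),
               std::string("first_quant_out"));
  // replaced == {"first_quant_out", "x", "first_quant_out"}: order preserved.
  return 0;
}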