未验证 提交 2def79bc 编写于 作者: Z Zuza 提交者: GitHub

Quantize elementwise mul (#40546)

* Quantize elementwise mul op

* Parametrize elementwise functions

* Fix code formatting
上级 dec2b1ca
...@@ -2052,18 +2052,19 @@ PDNode *patterns::Pool::operator()() { ...@@ -2052,18 +2052,19 @@ PDNode *patterns::Pool::operator()() {
return output_var; return output_var;
} }
PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) { PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var,
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr()) const std::string elementwise_type) {
->assert_is_op("elementwise_add"); auto elementwise_op =
pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
x_var->AsInput()->assert_is_op_input("elementwise_add", "X");
y_var->AsInput()->assert_is_op_input("elementwise_add", "Y"); x_var->AsInput()->assert_is_op_input(elementwise_type, "X");
auto out_var = pattern->NewNode(elementwise_add_out_repr()) y_var->AsInput()->assert_is_op_input(elementwise_type, "Y");
auto out_var = pattern->NewNode(elementwise_out_repr())
->AsOutput() ->AsOutput()
->assert_is_op_output("elementwise_add", "Out"); ->assert_is_op_output(elementwise_type, "Out");
elementwise_add_op->LinksFrom({x_var, y_var}); elementwise_op->LinksFrom({x_var, y_var});
elementwise_add_op->LinksTo({out_var}); elementwise_op->LinksTo({out_var});
return out_var; return out_var;
} }
......
...@@ -1016,20 +1016,20 @@ struct Pool : public PatternBase { ...@@ -1016,20 +1016,20 @@ struct Pool : public PatternBase {
PATTERN_DECL_NODE(pool_output); PATTERN_DECL_NODE(pool_output);
}; };
// ElementwiseAdd used in residual connections. // Elementwise ops
// y_var is used and convolution output. // Forward pass for element-wise operators (add, mul)
// The operator is removed, when residual // elementwise_out is the result of the operator
// connection fusion is on. struct Elementwise : public PatternBase {
struct ElementwiseAdd : public PatternBase { Elementwise(PDPattern* pattern, const std::string& name_scope)
ElementwiseAdd(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "elementwise") {}
: PatternBase(pattern, name_scope, "elementwise_add") {}
PDNode* operator()(PDNode* x_var, PDNode* y_var,
PDNode* operator()(PDNode* x_var, PDNode* y_var); const std::string elementwise_type);
PATTERN_DECL_NODE(elementwise_add_op); PATTERN_DECL_NODE(elementwise_op);
PATTERN_DECL_NODE(elementwise_add_x); PATTERN_DECL_NODE(elementwise_x);
PATTERN_DECL_NODE(elementwise_add_y); PATTERN_DECL_NODE(elementwise_y);
PATTERN_DECL_NODE(elementwise_add_out); PATTERN_DECL_NODE(elementwise_out);
}; };
// Transpose op // Transpose op
......
...@@ -145,10 +145,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( ...@@ -145,10 +145,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
patterns::Conv conv_pattern{pattern, name_scope}; patterns::Conv conv_pattern{pattern, name_scope};
auto conv_output = conv_pattern(); auto conv_output = conv_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope}; patterns::Elementwise elementwise_pattern{pattern, name_scope};
elementwise_add_pattern( elementwise_pattern(
conv_output, conv_output, pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr())); "elementwise_add");
conv_output->AsIntermediate(); conv_output->AsIntermediate();
int found_conv_as_x_count = 0; int found_conv_as_x_count = 0;
...@@ -160,16 +160,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( ...@@ -160,16 +160,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_add_pattern); elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_identity, elementwise_add_y, GET_IR_NODE_FROM_SUBGRAPH(elementwise_identity, elementwise_y,
elementwise_add_pattern); elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_add_pattern); elementwise_pattern);
if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return; if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return;
if (!IsReachable(g, elementwise_add_identity, conv_output)) return; if (!IsReachable(g, elementwise_identity, conv_output)) return;
if (HasFusedActivation(conv_op)) return; if (HasFusedActivation(conv_op)) return;
...@@ -179,14 +179,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( ...@@ -179,14 +179,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
return; return;
} }
conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()}); conv_op->Op()->SetInput("ResidualData", {elementwise_identity->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true); conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(g, {conv_output, elementwise_add_op}); GraphSafeRemoveNodes(g, {conv_output, elementwise_op});
IR_NODE_LINK_TO(elementwise_add_identity, conv_op); IR_NODE_LINK_TO(elementwise_identity, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_add_out); IR_NODE_LINK_TO(conv_op, elementwise_out);
found_conv_as_x_count++; found_conv_as_x_count++;
}; };
...@@ -212,10 +212,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( ...@@ -212,10 +212,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
patterns::Conv conv_pattern{pattern, name_scope}; patterns::Conv conv_pattern{pattern, name_scope};
auto conv_output = conv_pattern(); auto conv_output = conv_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope}; patterns::Elementwise elementwise_pattern{pattern, name_scope};
elementwise_add_pattern( elementwise_pattern(
pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()), pattern->NewNode(elementwise_pattern.elementwise_x_repr()), conv_output,
conv_output); "elementwise_add");
conv_output->AsIntermediate(); conv_output->AsIntermediate();
int found_conv_as_y_count = 0; int found_conv_as_y_count = 0;
...@@ -227,16 +227,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( ...@@ -227,16 +227,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_add_pattern); elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x,
elementwise_add_pattern); elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_add_pattern); elementwise_pattern);
if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return; if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return;
if (!IsReachable(g, elementwise_add_x, conv_output)) return; if (!IsReachable(g, elementwise_x, conv_output)) return;
if (HasFusedActivation(conv_op)) return; if (HasFusedActivation(conv_op)) return;
...@@ -246,14 +246,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( ...@@ -246,14 +246,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
return; return;
} }
conv_op->Op()->SetInput("ResidualData", {elementwise_add_x->Name()}); conv_op->Op()->SetInput("ResidualData", {elementwise_x->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true); conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(g, {conv_output, elementwise_add_op}); GraphSafeRemoveNodes(g, {conv_output, elementwise_op});
IR_NODE_LINK_TO(elementwise_add_x, conv_op); IR_NODE_LINK_TO(elementwise_x, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_add_out); IR_NODE_LINK_TO(conv_op, elementwise_out);
found_conv_as_y_count++; found_conv_as_y_count++;
}; };
...@@ -282,8 +282,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( ...@@ -282,8 +282,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
patterns::Conv conv_y_pattern{pattern, name_scope}; patterns::Conv conv_y_pattern{pattern, name_scope};
auto conv_y_output = conv_y_pattern(); auto conv_y_output = conv_y_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope}; patterns::Elementwise elementwise_pattern{pattern, name_scope};
elementwise_add_pattern(conv_x_output, conv_y_output); elementwise_pattern(conv_x_output, conv_y_output, "elementwise_add");
conv_x_output->AsIntermediate(); conv_x_output->AsIntermediate();
conv_y_output->AsIntermediate(); conv_y_output->AsIntermediate();
...@@ -301,10 +301,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( ...@@ -301,10 +301,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
GET_IR_NODE_FROM_SUBGRAPH(conv_y_filter, conv_filter, conv_y_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_y_filter, conv_filter, conv_y_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_y_output, conv_output, conv_y_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_y_output, conv_output, conv_y_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_add_pattern); elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_add_pattern); elementwise_pattern);
if (!IsCompat(subgraph, g)) { if (!IsCompat(subgraph, g)) {
LOG(WARNING) LOG(WARNING)
...@@ -312,8 +312,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( ...@@ -312,8 +312,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
return; return;
} }
if (FindFuseOption(*conv_x_op, *elementwise_add_op) != FUSE_MKLDNN) return; if (FindFuseOption(*conv_x_op, *elementwise_op) != FUSE_MKLDNN) return;
if (FindFuseOption(*conv_y_op, *elementwise_add_op) != FUSE_MKLDNN) return; if (FindFuseOption(*conv_y_op, *elementwise_op) != FUSE_MKLDNN) return;
Node* projection_node; Node* projection_node;
Node* residual_conv_op; Node* residual_conv_op;
...@@ -333,14 +333,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( ...@@ -333,14 +333,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
if (HasFusedActivation(residual_conv_op)) return; if (HasFusedActivation(residual_conv_op)) return;
residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()}); residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); residual_conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
residual_conv_op->Op()->SetAttr("fuse_residual_connection", true); residual_conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(g, {residual_conv_output, elementwise_add_op}); GraphSafeRemoveNodes(g, {residual_conv_output, elementwise_op});
IR_NODE_LINK_TO(projection_node, residual_conv_op); IR_NODE_LINK_TO(projection_node, residual_conv_op);
IR_NODE_LINK_TO(residual_conv_op, elementwise_add_out); IR_NODE_LINK_TO(residual_conv_op, elementwise_out);
found_projection_conv_count++; found_projection_conv_count++;
}; };
......
...@@ -807,74 +807,74 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -807,74 +807,74 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
PrettyLogDetail("--- quantized %d matmul ops", quantize_matmul_count); PrettyLogDetail("--- quantized %d matmul ops", quantize_matmul_count);
} }
void CPUQuantizePass::QuantizeElementwiseAdd(Graph* graph) const { void CPUQuantizePass::QuantizeElementwise(
Graph* graph, const std::string elementwise_type) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern(); auto pattern = gpd.mutable_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_}; patterns::Elementwise elementwise_pattern{pattern, name_scope_};
elementwise_add_pattern( elementwise_pattern(
pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()), pattern->NewNode(elementwise_pattern.elementwise_x_repr()),
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr())); pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
elementwise_type);
int quantize_elementwise_add_count = 0; int quantize_elementwise_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "Quantize elementwise_add op"; VLOG(4) << "Quantize " + elementwise_type + " op";
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_add_pattern); elementwise_pattern);
// skip if should not be quantized // skip if should not be quantized
if (!platform::HasOpINT8DataType(elementwise_add_op->Op())) { if (!platform::HasOpINT8DataType(elementwise_op->Op())) {
LogQuantizationDisabled(elementwise_add_op); LogQuantizationDisabled(elementwise_op);
return; return;
} }
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x,
elementwise_add_pattern); elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y, GET_IR_NODE_FROM_SUBGRAPH(elementwise_y, elementwise_y,
elementwise_add_pattern); elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_add_pattern); elementwise_pattern);
if (!AreScalesPresentForNodes( if (!AreScalesPresentForNodes(
{elementwise_add_x, elementwise_add_y, elementwise_add_out})) { {elementwise_x, elementwise_y, elementwise_out})) {
LogCannotQuantizeOp(elementwise_add_op, LogCannotQuantizeOp(elementwise_op,
"No scale available for the operator"); "No scale available for the operator");
return; return;
} }
bool is_x_unsigned{false}, is_y_unsigned{false}; bool is_x_unsigned{false}, is_y_unsigned{false};
auto input_x_scale = auto input_x_scale = GetScaleValueForNode(elementwise_x, &is_x_unsigned);
GetScaleValueForNode(elementwise_add_x, &is_x_unsigned); auto input_y_scale = GetScaleValueForNode(elementwise_y, &is_y_unsigned);
auto input_y_scale =
GetScaleValueForNode(elementwise_add_y, &is_y_unsigned);
// TODO(sfraczek): add support for different signedness // TODO(sfraczek): add support for different signedness
if (is_x_unsigned != is_y_unsigned) { if (is_x_unsigned != is_y_unsigned) {
LogCannotQuantizeOp(elementwise_add_op, LogCannotQuantizeOp(elementwise_op,
"ElementwiseAdd inputs must be of the same type."); "Elementwise inputs must be of the same type.");
return; return;
} }
QuantizeInput(g, elementwise_add_op, elementwise_add_x, "X", input_x_scale, QuantizeInput(g, elementwise_op, elementwise_x, "X", input_x_scale,
is_x_unsigned, "Scale_x"); is_x_unsigned, "Scale_x");
QuantizeInput(g, elementwise_add_op, elementwise_add_y, "Y", input_y_scale, QuantizeInput(g, elementwise_op, elementwise_y, "Y", input_y_scale,
is_y_unsigned, "Scale_y"); is_y_unsigned, "Scale_y");
bool is_output_unsigned{false}; bool is_output_unsigned{false};
auto output_scale = auto output_scale =
GetScaleValueForNode(elementwise_add_out, &is_output_unsigned); GetScaleValueForNode(elementwise_out, &is_output_unsigned);
DequantizeOutput(g, elementwise_add_op, elementwise_add_out, "Out", DequantizeOutput(g, elementwise_op, elementwise_out, "Out", output_scale,
output_scale, is_output_unsigned, "Scale_out"); is_output_unsigned, "Scale_out");
++quantize_elementwise_add_count; ++quantize_elementwise_count;
}; };
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_elementwise_add_count); AddStatis(quantize_elementwise_count);
PrettyLogDetail("--- quantized %d elementwise_add ops", PrettyLogDetail("--- quantized %d %s ops", quantize_elementwise_count,
quantize_elementwise_add_count); elementwise_type);
} }
void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const { void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const {
...@@ -1146,7 +1146,8 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -1146,7 +1146,8 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeFc(graph); QuantizeFc(graph);
QuantizeReshape(graph); QuantizeReshape(graph);
QuantizeMatmul(graph); QuantizeMatmul(graph);
QuantizeElementwiseAdd(graph); QuantizeElementwise(graph, "elementwise_add");
QuantizeElementwise(graph, "elementwise_mul");
QuantizeFusionGru(graph); QuantizeFusionGru(graph);
QuantizeMultiGru(graph); QuantizeMultiGru(graph);
QuantizeFusionLSTM(graph); QuantizeFusionLSTM(graph);
......
...@@ -57,7 +57,8 @@ class CPUQuantizePass : public FusePassBase { ...@@ -57,7 +57,8 @@ class CPUQuantizePass : public FusePassBase {
void QuantizeTranspose(Graph* graph) const; void QuantizeTranspose(Graph* graph) const;
void QuantizeReshape(Graph* graph) const; void QuantizeReshape(Graph* graph) const;
void QuantizeMatmul(Graph* graph) const; void QuantizeMatmul(Graph* graph) const;
void QuantizeElementwiseAdd(Graph* graph) const; void QuantizeElementwise(Graph* graph,
const std::string elementwise_type) const;
void QuantizeFusionGru(Graph* graph) const; void QuantizeFusionGru(Graph* graph) const;
void QuantizeMultiGru(Graph* graph) const; void QuantizeMultiGru(Graph* graph) const;
void QuantizeFusionLSTM(Graph* graph) const; void QuantizeFusionLSTM(Graph* graph) const;
......
...@@ -90,7 +90,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -90,7 +90,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetAttr("Scale_x", 1.0f); op->SetAttr("Scale_x", 1.0f);
op->SetAttr("Scale_y", 1.0f); op->SetAttr("Scale_y", 1.0f);
op->SetAttr("Scale_out", 1.0f); op->SetAttr("Scale_out", 1.0f);
} else if (type == "elementwise_add") { } else if (type == "elementwise_add" || type == "elementwise_mul") {
op->SetInput("X", {inputs[0]}); op->SetInput("X", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Y", {inputs[1]}); if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
...@@ -167,7 +167,8 @@ void CheckScales(const OpDesc* op, float scale, float shift) { ...@@ -167,7 +167,8 @@ void CheckScales(const OpDesc* op, float scale, float shift) {
scale); scale);
scale_names.push_back("Scale_in"); scale_names.push_back("Scale_in");
scale_names.push_back("Scale_out"); scale_names.push_back("Scale_out");
} else if (type == "matmul" || type == "elementwise_add") { } else if (type == "matmul" || type == "elementwise_add" ||
type == "elementwise_mul") {
scale_names.push_back("Scale_x"); scale_names.push_back("Scale_x");
scale_names.push_back("Scale_y"); scale_names.push_back("Scale_y");
scale_names.push_back("Scale_out"); scale_names.push_back("Scale_out");
...@@ -546,46 +547,77 @@ TEST(CpuQuantizePass, matmul_not_quantized) { ...@@ -546,46 +547,77 @@ TEST(CpuQuantizePass, matmul_not_quantized) {
expected_operators, added_nodes, 1.0f); expected_operators, added_nodes, 1.0f);
} }
static const std::initializer_list<std::string> variable_names_elementwise_add = static const std::initializer_list<std::string> variable_names_elementwise = {
{"a", "b", "c", "d", "e", "f"}; "a", "b", "c", "d", "e", "f"};
ProgramDesc BuildProgramDescElementwiseAdd() { ProgramDesc BuildProgramDescElementwise(const std::string elementwise_type,
const std::string elementwise_name) {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : variable_names_elementwise_add) { for (auto& v : variable_names_elementwise) {
prog.MutableBlock(0)->Var(v); prog.MutableBlock(0)->Var(v);
} }
SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true); SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
SetOp(&prog, "dequantize", "Dequantize2", {"c"}, {"d"}, true); SetOp(&prog, "dequantize", "Dequantize2", {"c"}, {"d"}, true);
SetOp(&prog, "elementwise_add", "ElementwiseAdd", {"b", "d"}, {"e"}, true, SetOp(&prog, elementwise_type, elementwise_name, {"b", "d"}, {"e"}, true,
"int8"); "int8");
SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, "float32"); SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, "float32");
return prog; return prog;
} }
TEST(CpuQuantizePass, elementwise_add) { void TestElementwise(const std::string elementwise_type,
const std::string elementwise_name) {
// 2 Quant + 2 IN + 1 DeQuant + 1 OUT // 2 Quant + 2 IN + 1 DeQuant + 1 OUT
int added_nodes = 6; int added_nodes = 6;
std::unordered_map<std::string, int> expected_operators = { std::unordered_map<std::string, int> expected_operators = {
{"elementwise_add", 1}, {"quantize", 2}, {"dequantize", 3}}; {elementwise_type, 1}, {"quantize", 2}, {"dequantize", 3}};
MainTest(BuildProgramDescElementwiseAdd(), variable_names_elementwise_add, MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name),
expected_operators, added_nodes, SCALE * S8_MAX); variable_names_elementwise, expected_operators, added_nodes,
SCALE * S8_MAX);
} }
TEST(CpuQuantizePass, elementwise_add_output_scale_missing) { void TestElementwiseOutputScaleMissing(const std::string elementwise_type,
const std::string elementwise_name) {
int added_nodes = 0; int added_nodes = 0;
std::unordered_map<std::string, int> expected_operators = { std::unordered_map<std::string, int> expected_operators = {
{"elementwise_add", 1}, {"quantize", 0}, {"dequantize", 2}}; {elementwise_type, 1}, {"quantize", 0}, {"dequantize", 2}};
MainTest(BuildProgramDescElementwiseAdd(), variable_names_elementwise_add, MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name),
expected_operators, added_nodes, 1.f, 1.f, "e"); variable_names_elementwise, expected_operators, added_nodes, 1.f,
1.f, "e");
} }
TEST(CpuQuantizePass, elementwise_add_unsigned_and_signed_input) { void TestElementwiseUnsignedAndSignedInput(const std::string elementwise_type,
const std::string elementwise_name) {
int added_nodes = 0; int added_nodes = 0;
std::unordered_map<std::string, int> expected_operators = { std::unordered_map<std::string, int> expected_operators = {
{"elementwise_add", 1}, {"quantize", 0}, {"dequantize", 2}}; {elementwise_type, 1}, {"quantize", 0}, {"dequantize", 2}};
MainTest(BuildProgramDescElementwiseAdd(), variable_names_elementwise_add, MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name),
expected_operators, added_nodes, 1.f, 1.f, "", "b"); variable_names_elementwise, expected_operators, added_nodes, 1.f,
1.f, "", "b");
}
TEST(CpuQuantizePass, elementwise_add) {
TestElementwise("elementwise_add", "ElementwiseAdd");
}
TEST(CpuQuantizePass, elementwise_add_output_scale_missing) {
TestElementwiseOutputScaleMissing("elementwise_add", "ElementwiseAdd");
}
TEST(CpuQuantizePass, elementwise_add_unsigned_and_signed_input) {
TestElementwiseUnsignedAndSignedInput("elementwise_add", "ElementwiseAdd");
}
TEST(CpuQuantizePass, elementwise_mul) {
TestElementwise("elementwise_mul", "ElementwiseMul");
}
TEST(CpuQuantizePass, elementwise_mul_output_scale_missing) {
TestElementwiseOutputScaleMissing("elementwise_mul", "ElementwiseMul");
}
TEST(CpuQuantizePass, elementwise_mul_unsigned_and_signed_input) {
TestElementwiseUnsignedAndSignedInput("elementwise_mul", "ElementwiseMul");
} }
const std::vector<std::string> churn_out_vars(ProgramDesc* prog, const std::vector<std::string> churn_out_vars(ProgramDesc* prog,
......
...@@ -26,10 +26,10 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const { ...@@ -26,10 +26,10 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Marks operators which are to be quantized."; VLOG(3) << "Marks operators which are to be quantized.";
std::unordered_set<std::string> supported_op_types = std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>( std::unordered_set<std::string>(
{"concat", "conv2d", "depthwise_conv2d", "elementwise_add", "fc", {"concat", "conv2d", "depthwise_conv2d", "elementwise_add",
"matmul", "nearest_interp", "nearest_interp_v2", "pool2d", "elementwise_mul", "fc", "matmul", "nearest_interp",
"prior_box", "reshape2", "transpose2", "fusion_gru", "fusion_lstm", "nearest_interp_v2", "pool2d", "prior_box", "reshape2", "transpose2",
"multi_gru", "slice"}); "fusion_gru", "fusion_lstm", "multi_gru", "slice"});
const auto& excluded_ids_list = const auto& excluded_ids_list =
Get<std::unordered_set<int>>("quantize_excluded_op_ids"); Get<std::unordered_set<int>>("quantize_excluded_op_ids");
const auto& op_types_list = const auto& op_types_list =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册