未验证 提交 e0d7d790 编写于 作者: Z Zuza Gawrysiak 提交者: GitHub

Refactor quantization of immutable ops (#43973)

* Refactor quantization of immutable ops

* Fix code formatting

* Fix formatting

* Specify input names

* Fix formatting

* Change string to reference

* Formatting
上级 59813de9
...@@ -1802,80 +1802,23 @@ PDNode *patterns::Conv::operator()() { ...@@ -1802,80 +1802,23 @@ PDNode *patterns::Conv::operator()() {
return output_var; return output_var;
} }
PDNode *patterns::Transpose::operator()() { PDNode *patterns::Immutable::operator()(const std::string &immutable_type,
const std::string &input_name) {
auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op(); auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
auto transpose_op = auto immutable_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2"); pattern->NewNode(immutable_op_repr())->assert_is_op(immutable_type);
auto transpose_in = pattern->NewNode(transpose_in_repr()) auto immutable_in = pattern->NewNode(immutable_in_repr())
->AsInput() ->AsInput()
->assert_is_op_input("transpose2"); ->assert_is_op_input(immutable_type, input_name);
auto transpose_out = pattern->NewNode(transpose_out_repr()) auto immutable_out = pattern->NewNode(immutable_out_repr())
->AsOutput() ->AsOutput()
->assert_is_op_output("transpose2", "Out"); ->assert_is_op_output(immutable_type, "Out");
prev_op->LinksTo({transpose_in}); prev_op->LinksTo({immutable_in});
transpose_op->LinksFrom({transpose_in}).LinksTo({transpose_out}); immutable_op->LinksFrom({immutable_in}).LinksTo({immutable_out});
return transpose_out; return immutable_out;
}
// Describes a reshape2 operator together with the operator feeding it:
//   prev_op -> reshape_in -> reshape2 -> reshape_out
// The node for the operator's "Out" variable is returned.
PDNode *patterns::Reshape::operator()() {
  const char *const op_type = "reshape2";

  // assert_is_op() is called without an operator type for the producer.
  auto producer = pattern->NewNode(prev_op_repr())->assert_is_op();
  auto op_node = pattern->NewNode(reshape_op_repr())->assert_is_op(op_type);

  // Variable consumed through the operator's "X" input slot.
  auto input_var = pattern->NewNode(reshape_in_repr())
                       ->AsInput()
                       ->assert_is_op_input(op_type, "X");

  // Variable produced through the operator's "Out" output slot.
  auto output_var = pattern->NewNode(reshape_out_repr())
                        ->AsOutput()
                        ->assert_is_op_output(op_type, "Out");

  // Wire the nodes together in the same order as before.
  producer->LinksTo({input_var});
  op_node->LinksFrom({input_var}).LinksTo({output_var});

  return output_var;
}
// Describes a slice operator together with the operator feeding it:
//   prev_op -> slice_in -> slice -> slice_out
// The node for the operator's "Out" variable is returned.
PDNode *patterns::Slice::operator()() {
  const char *const op_type = "slice";

  // assert_is_op() is called without an operator type for the producer.
  auto producer = pattern->NewNode(prev_op_repr())->assert_is_op();
  auto op_node = pattern->NewNode(slice_op_repr())->assert_is_op(op_type);

  // Variable consumed through the operator's "Input" input slot.
  auto input_var = pattern->NewNode(slice_in_repr())
                       ->AsInput()
                       ->assert_is_op_input(op_type, "Input");

  // Variable produced through the operator's "Out" output slot.
  auto output_var = pattern->NewNode(slice_out_repr())
                        ->AsOutput()
                        ->assert_is_op_output(op_type, "Out");

  // Wire the nodes together in the same order as before.
  producer->LinksTo({input_var});
  op_node->LinksFrom({input_var}).LinksTo({output_var});

  return output_var;
}
PDNode *patterns::NearestInterp::operator()() {
auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
auto nearest_interp_op =
pattern->NewNode(nearest_interp_op_repr())
->assert_is_ops({"nearest_interp", "nearest_interp_v2"});
auto nearest_interp_in =
pattern->NewNode(nearest_interp_in_repr())
->AsInput()
->assert_is_ops_input({"nearest_interp", "nearest_interp_v2"}, "X");
auto nearest_interp_out =
pattern->NewNode(nearest_interp_out_repr())
->AsOutput()
->assert_is_ops_output({"nearest_interp", "nearest_interp_v2"},
"Out");
prev_op->LinksTo({nearest_interp_in});
nearest_interp_op->LinksFrom({nearest_interp_in})
.LinksTo({nearest_interp_out});
return nearest_interp_out;
} }
PDNode *patterns::Matmul::operator()() { PDNode *patterns::Matmul::operator()() {
...@@ -2118,7 +2061,7 @@ PDNode *patterns::Pool::operator()() { ...@@ -2118,7 +2061,7 @@ PDNode *patterns::Pool::operator()() {
PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *patterns::Elementwise::operator()(PDNode *x_var,
PDNode *y_var, PDNode *y_var,
const std::string elementwise_type) { const std::string &elementwise_type) {
auto elementwise_op = auto elementwise_op =
pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type); pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
...@@ -2135,7 +2078,7 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, ...@@ -2135,7 +2078,7 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var,
} }
PDNode *patterns::ElementwiseOp::operator()( PDNode *patterns::ElementwiseOp::operator()(
const std::string elementwise_type) { const std::string &elementwise_type) {
auto elementwise_op = auto elementwise_op =
pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type); pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
...@@ -2151,7 +2094,7 @@ PDNode *patterns::ElementwiseOp::operator()( ...@@ -2151,7 +2094,7 @@ PDNode *patterns::ElementwiseOp::operator()(
PDNode *patterns::ResidualElementwise::operator()( PDNode *patterns::ResidualElementwise::operator()(
PDNode *op_var, PDNode *op_var,
PDNode *residual_var, PDNode *residual_var,
const std::string elementwise_type, const std::string &elementwise_type,
bool as_x) { bool as_x) {
auto elementwise_op = auto elementwise_op =
pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type); pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
......
...@@ -1087,7 +1087,7 @@ struct Elementwise : public PatternBase { ...@@ -1087,7 +1087,7 @@ struct Elementwise : public PatternBase {
PDNode* operator()(PDNode* x_var, PDNode* operator()(PDNode* x_var,
PDNode* y_var, PDNode* y_var,
const std::string elementwise_type); const std::string& elementwise_type);
PATTERN_DECL_NODE(elementwise_op); PATTERN_DECL_NODE(elementwise_op);
PATTERN_DECL_NODE(elementwise_x); PATTERN_DECL_NODE(elementwise_x);
...@@ -1102,7 +1102,7 @@ struct ElementwiseOp : public PatternBase { ...@@ -1102,7 +1102,7 @@ struct ElementwiseOp : public PatternBase {
ElementwiseOp(PDPattern* pattern, const std::string& name_scope) ElementwiseOp(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise") {} : PatternBase(pattern, name_scope, "elementwise") {}
PDNode* operator()(const std::string elementwise_type); PDNode* operator()(const std::string& elementwise_type);
PATTERN_DECL_NODE(elementwise_op); PATTERN_DECL_NODE(elementwise_op);
PATTERN_DECL_NODE(elementwise_out); PATTERN_DECL_NODE(elementwise_out);
...@@ -1118,7 +1118,7 @@ struct ResidualElementwise : public PatternBase { ...@@ -1118,7 +1118,7 @@ struct ResidualElementwise : public PatternBase {
: PatternBase(pattern, name_scope, "residual_elementwise") {} : PatternBase(pattern, name_scope, "residual_elementwise") {}
PDNode* operator()(PDNode* op_var, PDNode* operator()(PDNode* op_var,
PDNode* residual_var, PDNode* residual_var,
const std::string elementwise_type, const std::string& elementwise_type,
bool as_x); bool as_x);
PATTERN_DECL_NODE(operator_output); PATTERN_DECL_NODE(operator_output);
...@@ -1127,59 +1127,20 @@ struct ResidualElementwise : public PatternBase { ...@@ -1127,59 +1127,20 @@ struct ResidualElementwise : public PatternBase {
PATTERN_DECL_NODE(elementwise_out); PATTERN_DECL_NODE(elementwise_out);
}; };
// Transpose op // General struct for immutable ops:
// Forward pass for transpose. // reshape, transpose, slice, nearest-interp
// transpose_out is a result of the operator. // Forward pass for no weights-op.
struct Transpose : public PatternBase { // immutable_out is a result of the operator.
Transpose(PDPattern* pattern, const std::string& name_scope) struct Immutable : public PatternBase {
: PatternBase(pattern, name_scope, "transpose2") {} Immutable(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "immutable") {}
PDNode* operator()(); PDNode* operator()(const std::string& immutable_type,
PATTERN_DECL_NODE(prev_op); const std::string& input_name);
PATTERN_DECL_NODE(transpose_in);
PATTERN_DECL_NODE(transpose_op);
PATTERN_DECL_NODE(transpose_out);
};
// Reshape op
// Pattern for a reshape2 operator and the operator that produces its input:
//   prev_op -> reshape_in -> reshape_op ("reshape2") -> reshape_out
// operator()() registers these four nodes on the pattern and returns
// reshape_out, the variable attached to the operator's "Out" output.
// reshape_in is attached to the operator's "X" input.
struct Reshape : public PatternBase {
// Forwards pattern and name_scope to PatternBase together with the fixed
// string "reshape2".
Reshape(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "reshape2") {}
// Builds the pattern; the result is the reshape_out node.
PDNode* operator()();
// Node declarations; each name has a matching <name>_repr() used when the
// node is created in operator()().
PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(reshape_in);
PATTERN_DECL_NODE(reshape_op);
PATTERN_DECL_NODE(reshape_out);
};
// Slice op
// Pattern for a slice operator and the operator that produces its input:
//   prev_op -> slice_in -> slice_op ("slice") -> slice_out
// operator()() registers these four nodes on the pattern and returns
// slice_out, the variable attached to the operator's "Out" output.
// slice_in is attached to the operator's "Input" input.
struct Slice : public PatternBase {
// Forwards pattern and name_scope to PatternBase together with the fixed
// string "slice".
Slice(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "slice") {}
// Builds the pattern; the result is the slice_out node.
PDNode* operator()();
// Node declarations; each name has a matching <name>_repr() used when the
// node is created in operator()().
PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(slice_in);
PATTERN_DECL_NODE(slice_op);
PATTERN_DECL_NODE(slice_out);
};
// Nearest Interp op
// Forward pass for nearest_interp.
// nearest_interp_out is a result of the operator.
struct NearestInterp : public PatternBase {
NearestInterp(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "nearest_interp") {}
PDNode* operator()();
PATTERN_DECL_NODE(prev_op); PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(nearest_interp_in); PATTERN_DECL_NODE(immutable_in);
PATTERN_DECL_NODE(nearest_interp_op); PATTERN_DECL_NODE(immutable_op);
PATTERN_DECL_NODE(nearest_interp_out); PATTERN_DECL_NODE(immutable_out);
}; };
// Matmul op // Matmul op
......
...@@ -669,165 +669,68 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const { ...@@ -669,165 +669,68 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
LogQuantizedOpsCounter("prior_box", quantize_prior_box_count); LogQuantizedOpsCounter("prior_box", quantize_prior_box_count);
} }
void CPUQuantizePass::QuantizeTranspose(Graph* graph) const { void CPUQuantizePass::QuantizeImmutable(Graph* graph,
const std::string& immutable_type,
const std::string& input_name) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern(); auto pattern = gpd.mutable_pattern();
patterns::Transpose transpose_pattern{pattern, name_scope_}; patterns::Immutable immutable_pattern{pattern, name_scope_};
transpose_pattern(); immutable_pattern(immutable_type, input_name);
int quantize_transpose_count = 0; int quantize_immutable_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "Quantize transpose op"; VLOG(4) << "Quantize " + immutable_type + " op";
GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, transpose_pattern); GET_IR_NODE_FROM_SUBGRAPH(immutable_op, immutable_op, immutable_pattern);
// skip if should not be quantized // skip if should not be quantized
if (!platform::HasOpINT8DataType(transpose_op->Op())) { if (!platform::HasOpINT8DataType(immutable_op->Op())) {
LogQuantizationDisabled(transpose_op); LogQuantizationDisabled(immutable_op);
return; return;
} }
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, transpose_pattern); GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, immutable_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose_in, transpose_in, transpose_pattern); GET_IR_NODE_FROM_SUBGRAPH(immutable_in, immutable_in, immutable_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, transpose_pattern); GET_IR_NODE_FROM_SUBGRAPH(immutable_out, immutable_out, immutable_pattern);
// skip if prev op and next op is not quantized // skip if prev op and next op is not quantized
if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(transpose_out))) { if (!IsOpDequantized(prev_op) && !IsOpQuantized(immutable_out)) {
MarkAndLogCannotQuantizeOp(transpose_op, MarkAndLogCannotQuantizeOp(immutable_op,
"No other quantizable operators nearby"); "No other quantizable operators nearby");
return; return;
} }
if (!AreScalesPresentForNodes({transpose_in, transpose_out})) { if (!AreScalesPresentForNodes({immutable_out})) {
MarkAndLogCannotQuantizeOp(transpose_op, MarkAndLogCannotQuantizeOp(immutable_op,
"No scale available for the operator"); "No scale available for the operator");
return; return;
} }
bool is_input_unsigned{false}; bool is_input_unsigned{false};
auto input_scale = GetScaleValueForNode(transpose_in, &is_input_unsigned); auto input_scale = GetScaleValueForNode(immutable_out, &is_input_unsigned);
QuantizeInput(
g, transpose_op, transpose_in, "X", input_scale, is_input_unsigned); QuantizeInput(g,
immutable_op,
immutable_in,
input_name,
input_scale,
is_input_unsigned);
bool is_output_unsigned{false}; bool is_output_unsigned{false};
auto output_scale = auto output_scale =
GetScaleValueForNode(transpose_out, &is_output_unsigned); GetScaleValueForNode(immutable_out, &is_output_unsigned);
DequantizeOutput(g, DequantizeOutput(g,
transpose_op, immutable_op,
transpose_out, immutable_out,
"Out", "Out",
output_scale, output_scale,
is_output_unsigned); is_output_unsigned);
++quantize_transpose_count; ++quantize_immutable_count;
};
gpd(graph, handler);
AddStatis(quantize_transpose_count);
LogQuantizedOpsCounter("transpose2", quantize_transpose_count);
}
// Searches `graph` for the patterns::Reshape pattern
// (prev_op -> reshape_in -> reshape2 -> reshape_out) and, for every match that
// passes the checks below, calls QuantizeInput on input "X" and
// DequantizeOutput on output "Out" of the reshape2 op.
// A match is left untouched when:
//   * platform::HasOpINT8DataType() is false for the op,
//   * neither IsOpDequantized(prev_op) nor IsOpQuantized(reshape_out) holds,
//   * scales are missing for reshape_in or reshape_out.
// The number of handled matches is passed to AddStatis and logged under the
// name "reshape2".
void CPUQuantizePass::QuantizeReshape(Graph* graph) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Reshape reshape_pattern{pattern, name_scope_};
reshape_pattern();
int quantize_reshape_count = 0;
// Invoked by the detector once per matched subgraph.
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Quantize reshape op";
GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, reshape_pattern);
// skip if the op should not be quantized (no INT8 data type on the op)
if (!platform::HasOpINT8DataType(reshape_op->Op())) {
LogQuantizationDisabled(reshape_op);
return;
}
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, reshape_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, reshape_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, reshape_pattern);
// skip if neither the prev op nor the next op is quantized
if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(reshape_out))) {
MarkAndLogCannotQuantizeOp(reshape_op,
"No other quantizable operators nearby");
return;
}
// both the input and the output variable need a scale
if (!AreScalesPresentForNodes({reshape_in, reshape_out})) {
MarkAndLogCannotQuantizeOp(reshape_op,
"No scale available for the operator");
return;
}
// input side: the scale and signedness are looked up for reshape_in
bool is_input_unsigned{false};
auto input_scale = GetScaleValueForNode(reshape_in, &is_input_unsigned);
QuantizeInput(
g, reshape_op, reshape_in, "X", input_scale, is_input_unsigned);
// output side: the scale and signedness are looked up for reshape_out
bool is_output_unsigned{false};
auto output_scale = GetScaleValueForNode(reshape_out, &is_output_unsigned);
DequantizeOutput(
g, reshape_op, reshape_out, "Out", output_scale, is_output_unsigned);
++quantize_reshape_count;
};
gpd(graph, handler);
AddStatis(quantize_reshape_count);
LogQuantizedOpsCounter("reshape2", quantize_reshape_count);
}
void CPUQuantizePass::QuantizeSlice(Graph* graph) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Slice slice_pattern{pattern, name_scope_};
slice_pattern();
int quantize_slice_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Quantize slice op";
GET_IR_NODE_FROM_SUBGRAPH(slice_op, slice_op, slice_pattern);
// skip if should not be quantized
if (!platform::HasOpINT8DataType(slice_op->Op())) {
LogQuantizationDisabled(slice_op);
return;
}
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, slice_pattern);
GET_IR_NODE_FROM_SUBGRAPH(slice_in, slice_in, slice_pattern);
GET_IR_NODE_FROM_SUBGRAPH(slice_out, slice_out, slice_pattern);
// skip if prev op and next op is not quantized
if (!IsOpDequantized(prev_op) && !IsOpQuantized(slice_out)) {
MarkAndLogCannotQuantizeOp(slice_op,
"No other quantizable operators nearby");
return;
}
if (!AreScalesPresentForNodes({slice_out})) {
MarkAndLogCannotQuantizeOp(slice_op,
"No scale available for the operator");
return;
}
bool is_input_unsigned{false};
auto input_scale = GetScaleValueForNode(slice_out, &is_input_unsigned);
QuantizeInput(
g, slice_op, slice_in, "Input", input_scale, is_input_unsigned);
bool is_output_unsigned{false};
auto output_scale = GetScaleValueForNode(slice_out, &is_output_unsigned);
DequantizeOutput(
g, slice_op, slice_out, "Out", output_scale, is_output_unsigned);
++quantize_slice_count;
}; };
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_slice_count); AddStatis(quantize_immutable_count);
LogQuantizedOpsCounter("slice", quantize_slice_count); LogQuantizedOpsCounter(immutable_type, quantize_immutable_count);
} }
void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
...@@ -915,7 +818,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -915,7 +818,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
} }
void CPUQuantizePass::QuantizeElementwise( void CPUQuantizePass::QuantizeElementwise(
Graph* graph, const std::string elementwise_type) const { Graph* graph, const std::string& elementwise_type) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern(); auto pattern = gpd.mutable_pattern();
patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_}; patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_};
...@@ -1212,71 +1115,6 @@ void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const { ...@@ -1212,71 +1115,6 @@ void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const {
LogQuantizedOpsCounter("fusion_lstm", quantize_count); LogQuantizedOpsCounter("fusion_lstm", quantize_count);
} }
// Searches `graph` for the patterns::NearestInterp pattern
// (prev_op -> nearest_interp_in -> op -> nearest_interp_out) and, for every
// match that passes the checks below, calls QuantizeInput on input "X" and
// DequantizeOutput on output "Out" of the matched op.
// A match is left untouched when:
//   * platform::HasOpINT8DataType() is false for the op,
//   * neither IsOpDequantized(prev_op) nor IsOpQuantized(nearest_interp_out)
//     holds,
//   * scales are missing for nearest_interp_in or nearest_interp_out.
// NOTE(review): the pattern accepts both "nearest_interp" and
// "nearest_interp_v2" ops, but the counter below is logged only under the
// name "nearest_interp".
void CPUQuantizePass::QuantizeNearestInterp(Graph* graph) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::NearestInterp nearest_interp_pattern{pattern, name_scope_};
nearest_interp_pattern();
int quantize_nearest_interp_count = 0;
// Invoked by the detector once per matched subgraph.
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Quantize nearest_interp op";
GET_IR_NODE_FROM_SUBGRAPH(
nearest_interp_op, nearest_interp_op, nearest_interp_pattern);
// skip if the op should not be quantized (no INT8 data type on the op)
if (!platform::HasOpINT8DataType(nearest_interp_op->Op())) {
LogQuantizationDisabled(nearest_interp_op);
return;
}
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, nearest_interp_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
nearest_interp_in, nearest_interp_in, nearest_interp_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
nearest_interp_out, nearest_interp_out, nearest_interp_pattern);
// skip if neither the prev op nor the next op is quantized
if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(nearest_interp_out))) {
MarkAndLogCannotQuantizeOp(nearest_interp_op,
"No other quantizable operators nearby");
return;
}
// both the input and the output variable need a scale
if (!AreScalesPresentForNodes({nearest_interp_in, nearest_interp_out})) {
MarkAndLogCannotQuantizeOp(nearest_interp_op,
"No scale available for the operator");
return;
}
// input side: the scale and signedness are looked up for nearest_interp_in
bool is_input_unsigned{false};
auto input_scale =
GetScaleValueForNode(nearest_interp_in, &is_input_unsigned);
QuantizeInput(g,
nearest_interp_op,
nearest_interp_in,
"X",
input_scale,
is_input_unsigned);
// output side: the scale and signedness are looked up for nearest_interp_out
bool is_output_unsigned{false};
auto output_scale =
GetScaleValueForNode(nearest_interp_out, &is_output_unsigned);
DequantizeOutput(g,
nearest_interp_op,
nearest_interp_out,
"Out",
output_scale,
is_output_unsigned);
++quantize_nearest_interp_count;
};
gpd(graph, handler);
AddStatis(quantize_nearest_interp_count);
LogQuantizedOpsCounter("nearest_interp", quantize_nearest_interp_count);
}
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Quantizing the graph."; VLOG(3) << "Quantizing the graph.";
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
...@@ -1293,18 +1131,19 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -1293,18 +1131,19 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizePool(graph); QuantizePool(graph);
QuantizeConcat(graph); QuantizeConcat(graph);
QuantizePriorBox(graph); QuantizePriorBox(graph);
QuantizeTranspose(graph);
QuantizeFc(graph); QuantizeFc(graph);
QuantizeReshape(graph);
QuantizeMatmul(graph); QuantizeMatmul(graph);
QuantizeImmutable(graph, "reshape2", "X");
QuantizeImmutable(graph, "transpose2", "X");
QuantizeImmutable(graph, "slice", "Input");
QuantizeImmutable(graph, "nearest_interp", "X");
QuantizeImmutable(graph, "nearest_interp_v2", "X");
QuantizeElementwise(graph, "elementwise_add"); QuantizeElementwise(graph, "elementwise_add");
QuantizeElementwise(graph, "elementwise_mul"); QuantizeElementwise(graph, "elementwise_mul");
QuantizeElementwise(graph, "elementwise_sub"); QuantizeElementwise(graph, "elementwise_sub");
QuantizeFusionGru(graph); QuantizeFusionGru(graph);
QuantizeMultiGru(graph); QuantizeMultiGru(graph);
QuantizeFusionLSTM(graph); QuantizeFusionLSTM(graph);
QuantizeSlice(graph);
QuantizeNearestInterp(graph);
} }
} // namespace ir } // namespace ir
......
...@@ -54,16 +54,15 @@ class CPUQuantizePass : public FusePassBase { ...@@ -54,16 +54,15 @@ class CPUQuantizePass : public FusePassBase {
void QuantizePool(Graph* graph) const; void QuantizePool(Graph* graph) const;
void QuantizeConcat(Graph* graph) const; void QuantizeConcat(Graph* graph) const;
void QuantizePriorBox(Graph* graph) const; void QuantizePriorBox(Graph* graph) const;
void QuantizeTranspose(Graph* graph) const;
void QuantizeReshape(Graph* graph) const;
void QuantizeMatmul(Graph* graph) const; void QuantizeMatmul(Graph* graph) const;
void QuantizeElementwise(Graph* graph, void QuantizeElementwise(Graph* graph,
const std::string elementwise_type) const; const std::string& elementwise_type) const;
void QuantizeFusionGru(Graph* graph) const; void QuantizeFusionGru(Graph* graph) const;
void QuantizeMultiGru(Graph* graph) const; void QuantizeMultiGru(Graph* graph) const;
void QuantizeFusionLSTM(Graph* graph) const; void QuantizeFusionLSTM(Graph* graph) const;
void QuantizeSlice(Graph* graph) const; void QuantizeImmutable(Graph* graph,
void QuantizeNearestInterp(Graph* graph) const; const std::string& immutable_type,
const std::string& input_name) const;
void QuantizeInput(Graph* g, void QuantizeInput(Graph* g,
Node* op, Node* op,
......
...@@ -550,55 +550,29 @@ void TestImmutableOpWithManyOutputs(const std::string tested_op) { ...@@ -550,55 +550,29 @@ void TestImmutableOpWithManyOutputs(const std::string tested_op) {
SCALE * S8_MAX); SCALE * S8_MAX);
} }
TEST(CpuQuantizePass, reshape2) { TestImmutableOp("reshape2"); } const std::vector<std::string> immutables = {
"reshape2", "transpose2", "slice", "nearest_interp", "nearest_interp_v2"};
TEST(CpuQuantizePass, reshape2BetweenNonQuantizedOp) { class TestImmutables : public testing::TestWithParam<std::string> {};
TestImmutableOpBetweenNonQuantizedOp("reshape2");
}
// Per-operator test cases: each one forwards a fixed operator name to one of
// the shared helpers (TestImmutableOp, TestImmutableOpBetweenNonQuantizedOp,
// TestImmutableOpWithManyOutputs).

// reshape2: many-outputs helper.
TEST(CpuQuantizePass, reshape2WithManyOutputs) {
TestImmutableOpWithManyOutputs("reshape2");
}
// transpose2: basic, between-non-quantized-ops and many-outputs helpers.
TEST(CpuQuantizePass, transpose2) { TestImmutableOp("transpose2"); }
TEST(CpuQuantizePass, transpose2BetweenNonQuantizedOp) {
TestImmutableOpBetweenNonQuantizedOp("transpose2");
}
TEST(CpuQuantizePass, transpose2WithManyOutputs) {
TestImmutableOpWithManyOutputs("transpose2");
}
// slice: basic, between-non-quantized-ops and many-outputs helpers.
TEST(CpuQuantizePass, slice) { TestImmutableOp("slice"); }
TEST(CpuQuantizePass, sliceBetweenNonQuantizedOp) {
TestImmutableOpBetweenNonQuantizedOp("slice");
}
TEST(CpuQuantizePass, sliceWithManyOutputs) {
TestImmutableOpWithManyOutputs("slice");
}
TEST(CpuQuantizePass, nearestInterp) { TestImmutableOp("nearest_interp"); } TEST_P(TestImmutables, immutable_basic) { TestImmutableOp(GetParam()); }
TEST(CpuQuantizePass, nearestInterpBetweenNonQuantizedOp) {
TestImmutableOpBetweenNonQuantizedOp("nearest_interp");
}
TEST(CpuQuantizePass, nearestInterpWithManyOutputs) { TEST_P(TestImmutables, immutable_between_non_quantized) {
TestImmutableOpWithManyOutputs("nearest_interp"); TestImmutableOpBetweenNonQuantizedOp(GetParam());
} }
TEST(CpuQuantizePass, nearestInterpV2) { TestImmutableOp("nearest_interp_v2"); } TEST_P(TestImmutables, immutable_many_outputs) {
TestImmutableOpWithManyOutputs(GetParam());
TEST(CpuQuantizePass, nearestInterpV2BetweenNonQuantizedOp) {
TestImmutableOpBetweenNonQuantizedOp("nearest_interp_v2");
} }
TEST(CpuQuantizePass, nearestInterpV2WithManyOutputs) { INSTANTIATE_TEST_CASE_P(
TestImmutableOpWithManyOutputs("nearest_interp_v2"); CpuQuantizePass,
} TestImmutables,
testing::ValuesIn(immutables),
[](const ::testing::TestParamInfo<TestImmutables::ParamType>& info) {
std::string name = info.param;
return name;
});
static const std::initializer_list<std::string> variable_names_matmul = { static const std::initializer_list<std::string> variable_names_matmul = {
"a", "b", "c", "d", "e", "f"}; "a", "b", "c", "d", "e", "f"};
...@@ -735,7 +709,7 @@ TEST_P(TestElementwises, elementwise_unsigned_and_signed_input) { ...@@ -735,7 +709,7 @@ TEST_P(TestElementwises, elementwise_unsigned_and_signed_input) {
} }
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
Elementwises, CpuQuantizePass,
TestElementwises, TestElementwises,
testing::ValuesIn(elementwises), testing::ValuesIn(elementwises),
[](const ::testing::TestParamInfo<TestElementwises::ParamType>& info) { [](const ::testing::TestParamInfo<TestElementwises::ParamType>& info) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册