From e0d7d790112af610a03624602f3361bacde496eb Mon Sep 17 00:00:00 2001 From: Zuza Gawrysiak Date: Tue, 5 Jul 2022 13:08:38 +0200 Subject: [PATCH] Refactor quantization of immutable ops (#43973) * Refactor quantization of immutable ops * Fix code formatting * Fix formatting * Specify input names * Fix formatting * Change string to reference * Formatting --- .../framework/ir/graph_pattern_detector.cc | 85 ++----- .../framework/ir/graph_pattern_detector.h | 69 ++--- .../framework/ir/mkldnn/cpu_quantize_pass.cc | 235 +++--------------- .../framework/ir/mkldnn/cpu_quantize_pass.h | 9 +- .../ir/mkldnn/cpu_quantize_pass_tester.cc | 60 ++--- 5 files changed, 87 insertions(+), 371 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 7ad02fe5ab..154df498e7 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1802,80 +1802,23 @@ PDNode *patterns::Conv::operator()() { return output_var; } -PDNode *patterns::Transpose::operator()() { +PDNode *patterns::Immutable::operator()(const std::string &immutable_type, + const std::string &input_name) { auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op(); - auto transpose_op = - pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2"); + auto immutable_op = + pattern->NewNode(immutable_op_repr())->assert_is_op(immutable_type); - auto transpose_in = pattern->NewNode(transpose_in_repr()) + auto immutable_in = pattern->NewNode(immutable_in_repr()) ->AsInput() - ->assert_is_op_input("transpose2"); - auto transpose_out = pattern->NewNode(transpose_out_repr()) + ->assert_is_op_input(immutable_type, input_name); + auto immutable_out = pattern->NewNode(immutable_out_repr()) ->AsOutput() - ->assert_is_op_output("transpose2", "Out"); + ->assert_is_op_output(immutable_type, "Out"); - prev_op->LinksTo({transpose_in}); - transpose_op->LinksFrom({transpose_in}).LinksTo({transpose_out}); - return transpose_out; -} - -PDNode *patterns::Reshape::operator()() { - auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op(); - - auto reshape_op = - pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2"); - - auto reshape_in = pattern->NewNode(reshape_in_repr()) - ->AsInput() - ->assert_is_op_input("reshape2", "X"); - auto reshape_out = pattern->NewNode(reshape_out_repr()) - ->AsOutput() - ->assert_is_op_output("reshape2", "Out"); - - prev_op->LinksTo({reshape_in}); - reshape_op->LinksFrom({reshape_in}).LinksTo({reshape_out}); - return reshape_out; -} - -PDNode *patterns::Slice::operator()() { - auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op(); - - auto slice_op = pattern->NewNode(slice_op_repr())->assert_is_op("slice"); - - auto slice_in = pattern->NewNode(slice_in_repr()) - ->AsInput() - ->assert_is_op_input("slice", "Input"); - auto slice_out = pattern->NewNode(slice_out_repr()) - ->AsOutput() - ->assert_is_op_output("slice", "Out"); - - prev_op->LinksTo({slice_in}); - slice_op->LinksFrom({slice_in}).LinksTo({slice_out}); - return slice_out; -} - -PDNode *patterns::NearestInterp::operator()() { - auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op(); - - auto nearest_interp_op = - pattern->NewNode(nearest_interp_op_repr()) - ->assert_is_ops({"nearest_interp", "nearest_interp_v2"}); - - auto nearest_interp_in = - pattern->NewNode(nearest_interp_in_repr()) - ->AsInput() - ->assert_is_ops_input({"nearest_interp", "nearest_interp_v2"}, "X"); - auto nearest_interp_out = - pattern->NewNode(nearest_interp_out_repr()) - ->AsOutput() - ->assert_is_ops_output({"nearest_interp", "nearest_interp_v2"}, - "Out"); - - prev_op->LinksTo({nearest_interp_in}); - nearest_interp_op->LinksFrom({nearest_interp_in}) - .LinksTo({nearest_interp_out}); - return nearest_interp_out; + prev_op->LinksTo({immutable_in}); + immutable_op->LinksFrom({immutable_in}).LinksTo({immutable_out}); + return immutable_out; } PDNode *patterns::Matmul::operator()() { @@ -2118,7 +2061,7 @@ PDNode *patterns::Pool::operator()() { PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var, - const std::string elementwise_type) { + const std::string &elementwise_type) { auto elementwise_op = pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type); @@ -2135,7 +2078,7 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, } PDNode *patterns::ElementwiseOp::operator()( - const std::string elementwise_type) { + const std::string &elementwise_type) { auto elementwise_op = pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type); @@ -2151,7 +2094,7 @@ PDNode *patterns::ElementwiseOp::operator()( PDNode *patterns::ResidualElementwise::operator()( PDNode *op_var, PDNode *residual_var, - const std::string elementwise_type, + const std::string &elementwise_type, bool as_x) { auto elementwise_op = pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 29d645f6be..be14ef2dbf 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1087,7 +1087,7 @@ struct Elementwise : public PatternBase { PDNode* operator()(PDNode* x_var, PDNode* y_var, - const std::string elementwise_type); + const std::string& elementwise_type); PATTERN_DECL_NODE(elementwise_op); PATTERN_DECL_NODE(elementwise_x); @@ -1102,7 +1102,7 @@ struct ElementwiseOp : public PatternBase { ElementwiseOp(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "elementwise") {} - PDNode* operator()(const std::string elementwise_type); + PDNode* operator()(const std::string& elementwise_type); PATTERN_DECL_NODE(elementwise_op); PATTERN_DECL_NODE(elementwise_out); @@ -1118,7 +1118,7 @@ struct ResidualElementwise : public PatternBase { : PatternBase(pattern, name_scope, "residual_elementwise") {} PDNode* operator()(PDNode* op_var, PDNode* residual_var, - const std::string elementwise_type, + const std::string& elementwise_type, bool as_x); PATTERN_DECL_NODE(operator_output); @@ -1127,59 +1127,20 @@ struct ResidualElementwise : public PatternBase { PATTERN_DECL_NODE(elementwise_out); }; -// Transpose op -// Forward pass for transpose. -// transpose_out is a result of the operator. -struct Transpose : public PatternBase { - Transpose(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "transpose2") {} +// General struct for immutable ops: +// reshape, transpose, slice, nearest-interp +// Forward pass for no weights-op. +// immutable_out is a result of the operator. +struct Immutable : public PatternBase { + Immutable(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "immutable") {} - PDNode* operator()(); - PATTERN_DECL_NODE(prev_op); - PATTERN_DECL_NODE(transpose_in); - PATTERN_DECL_NODE(transpose_op); - PATTERN_DECL_NODE(transpose_out); -}; - -// Reshape op -// Forward pass for reshape. -// reshape_out is a result of the operator. -struct Reshape : public PatternBase { - Reshape(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "reshape2") {} - - PDNode* operator()(); - PATTERN_DECL_NODE(prev_op); - PATTERN_DECL_NODE(reshape_in); - PATTERN_DECL_NODE(reshape_op); - PATTERN_DECL_NODE(reshape_out); -}; -// Slice op -// Forward pass for slice. -// slice_out is a result of the operator. -struct Slice : public PatternBase { - Slice(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "slice") {} - - PDNode* operator()(); - PATTERN_DECL_NODE(prev_op); - PATTERN_DECL_NODE(slice_in); - PATTERN_DECL_NODE(slice_op); - PATTERN_DECL_NODE(slice_out); -}; - -// Nearest Interp op -// Forward pass for nearest_interp. -// nearest_interp_out is a result of the operator. -struct NearestInterp : public PatternBase { - NearestInterp(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "nearest_interp") {} - - PDNode* operator()(); + PDNode* operator()(const std::string& immutable_type, + const std::string& input_name); PATTERN_DECL_NODE(prev_op); - PATTERN_DECL_NODE(nearest_interp_in); - PATTERN_DECL_NODE(nearest_interp_op); - PATTERN_DECL_NODE(nearest_interp_out); + PATTERN_DECL_NODE(immutable_in); + PATTERN_DECL_NODE(immutable_op); + PATTERN_DECL_NODE(immutable_out); }; // Matmul op diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index b0d41c16f5..26a4478fff 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -669,165 +669,68 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const { LogQuantizedOpsCounter("prior_box", quantize_prior_box_count); } -void CPUQuantizePass::QuantizeTranspose(Graph* graph) const { +void CPUQuantizePass::QuantizeImmutable(Graph* graph, + const std::string& immutable_type, + const std::string& input_name) const { GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); - patterns::Transpose transpose_pattern{pattern, name_scope_}; - transpose_pattern(); + patterns::Immutable immutable_pattern{pattern, name_scope_}; + immutable_pattern(immutable_type, input_name); - int quantize_transpose_count = 0; + int quantize_immutable_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - VLOG(4) << "Quantize transpose op"; - GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, transpose_pattern); + VLOG(4) << "Quantize " + immutable_type + " op"; + GET_IR_NODE_FROM_SUBGRAPH(immutable_op, immutable_op, immutable_pattern); // skip if should not be quantized - if (!platform::HasOpINT8DataType(transpose_op->Op())) { - LogQuantizationDisabled(transpose_op); + if (!platform::HasOpINT8DataType(immutable_op->Op())) { + LogQuantizationDisabled(immutable_op); return; } - GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, transpose_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose_in, transpose_in, transpose_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, transpose_pattern); + GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, immutable_pattern); + GET_IR_NODE_FROM_SUBGRAPH(immutable_in, immutable_in, immutable_pattern); + GET_IR_NODE_FROM_SUBGRAPH(immutable_out, immutable_out, immutable_pattern); // skip if prev op and next op is not quantized - if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(transpose_out))) { - MarkAndLogCannotQuantizeOp(transpose_op, + if (!IsOpDequantized(prev_op) && !IsOpQuantized(immutable_out)) { + MarkAndLogCannotQuantizeOp(immutable_op, "No other quantizable operators nearby"); return; } - if (!AreScalesPresentForNodes({transpose_in, transpose_out})) { - MarkAndLogCannotQuantizeOp(transpose_op, + if (!AreScalesPresentForNodes({immutable_out})) { + MarkAndLogCannotQuantizeOp(immutable_op, "No scale available for the operator"); return; } bool is_input_unsigned{false}; - auto input_scale = GetScaleValueForNode(transpose_in, &is_input_unsigned); - QuantizeInput( - g, transpose_op, transpose_in, "X", input_scale, is_input_unsigned); + auto input_scale = GetScaleValueForNode(immutable_out, &is_input_unsigned); + + QuantizeInput(g, + immutable_op, + immutable_in, + input_name, + input_scale, + is_input_unsigned); bool is_output_unsigned{false}; auto output_scale = - GetScaleValueForNode(transpose_out, &is_output_unsigned); + GetScaleValueForNode(immutable_out, &is_output_unsigned); DequantizeOutput(g, - transpose_op, - transpose_out, + immutable_op, + immutable_out, "Out", output_scale, is_output_unsigned); - ++quantize_transpose_count; + ++quantize_immutable_count; }; gpd(graph, handler); - AddStatis(quantize_transpose_count); - LogQuantizedOpsCounter("transpose2", quantize_transpose_count); -} - -void CPUQuantizePass::QuantizeReshape(Graph* graph) const { - GraphPatternDetector gpd; - auto pattern = gpd.mutable_pattern(); - patterns::Reshape reshape_pattern{pattern, name_scope_}; - reshape_pattern(); - - int quantize_reshape_count = 0; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - VLOG(4) << "Quantize reshape op"; - GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, reshape_pattern); - - // skip if should not be quantized - if (!platform::HasOpINT8DataType(reshape_op->Op())) { - LogQuantizationDisabled(reshape_op); - return; - } - GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, reshape_pattern); - GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, reshape_pattern); - GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, reshape_pattern); - - // skip if prev op is not quantized - if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(reshape_out))) { - MarkAndLogCannotQuantizeOp(reshape_op, - "No other quantizable operators nearby"); - return; - } - - if (!AreScalesPresentForNodes({reshape_in, reshape_out})) { - MarkAndLogCannotQuantizeOp(reshape_op, - "No scale available for the operator"); - return; - } - - bool is_input_unsigned{false}; - auto input_scale = GetScaleValueForNode(reshape_in, &is_input_unsigned); - QuantizeInput( - g, reshape_op, reshape_in, "X", input_scale, is_input_unsigned); - - bool is_output_unsigned{false}; - auto output_scale = GetScaleValueForNode(reshape_out, &is_output_unsigned); - DequantizeOutput( - g, reshape_op, reshape_out, "Out", output_scale, is_output_unsigned); - - ++quantize_reshape_count; - }; - - gpd(graph, handler); - AddStatis(quantize_reshape_count); - LogQuantizedOpsCounter("reshape2", quantize_reshape_count); -} - -void CPUQuantizePass::QuantizeSlice(Graph* graph) const { - GraphPatternDetector gpd; - auto pattern = gpd.mutable_pattern(); - patterns::Slice slice_pattern{pattern, name_scope_}; - slice_pattern(); - - int quantize_slice_count = 0; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - VLOG(4) << "Quantize slice op"; - GET_IR_NODE_FROM_SUBGRAPH(slice_op, slice_op, slice_pattern); - - // skip if should not be quantized - if (!platform::HasOpINT8DataType(slice_op->Op())) { - LogQuantizationDisabled(slice_op); - return; - } - GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, slice_pattern); - GET_IR_NODE_FROM_SUBGRAPH(slice_in, slice_in, slice_pattern); - GET_IR_NODE_FROM_SUBGRAPH(slice_out, slice_out, slice_pattern); - - // skip if prev op and next op is not quantized - if (!IsOpDequantized(prev_op) && !IsOpQuantized(slice_out)) { - MarkAndLogCannotQuantizeOp(slice_op, - "No other quantizable operators nearby"); - return; - } - - if (!AreScalesPresentForNodes({slice_out})) { - MarkAndLogCannotQuantizeOp(slice_op, - "No scale available for the operator"); - return; - } - - bool is_input_unsigned{false}; - auto input_scale = GetScaleValueForNode(slice_out, &is_input_unsigned); - QuantizeInput( - g, slice_op, slice_in, "Input", input_scale, is_input_unsigned); - - bool is_output_unsigned{false}; - auto output_scale = GetScaleValueForNode(slice_out, &is_output_unsigned); - DequantizeOutput( - g, slice_op, slice_out, "Out", output_scale, is_output_unsigned); - - ++quantize_slice_count; - }; - - gpd(graph, handler); - AddStatis(quantize_slice_count); - LogQuantizedOpsCounter("slice", quantize_slice_count); + AddStatis(quantize_immutable_count); + LogQuantizedOpsCounter(immutable_type, quantize_immutable_count); } void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { @@ -915,7 +818,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { } void CPUQuantizePass::QuantizeElementwise( - Graph* graph, const std::string elementwise_type) const { + Graph* graph, const std::string& elementwise_type) const { GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_}; @@ -1212,71 +1115,6 @@ void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const { LogQuantizedOpsCounter("fusion_lstm", quantize_count); } -void CPUQuantizePass::QuantizeNearestInterp(Graph* graph) const { - GraphPatternDetector gpd; - auto pattern = gpd.mutable_pattern(); - patterns::NearestInterp nearest_interp_pattern{pattern, name_scope_}; - nearest_interp_pattern(); - - int quantize_nearest_interp_count = 0; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - VLOG(4) << "Quantize nearest_interp op"; - GET_IR_NODE_FROM_SUBGRAPH( - nearest_interp_op, nearest_interp_op, nearest_interp_pattern); - - // skip if should not be quantized - if (!platform::HasOpINT8DataType(nearest_interp_op->Op())) { - LogQuantizationDisabled(nearest_interp_op); - return; - } - GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, nearest_interp_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - nearest_interp_in, nearest_interp_in, nearest_interp_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - nearest_interp_out, nearest_interp_out, nearest_interp_pattern); - - // skip if prev op and next op is not quantized - if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(nearest_interp_out))) { - MarkAndLogCannotQuantizeOp(nearest_interp_op, - "No other quantizable operators nearby"); - return; - } - - if (!AreScalesPresentForNodes({nearest_interp_in, nearest_interp_out})) { - MarkAndLogCannotQuantizeOp(nearest_interp_op, - "No scale available for the operator"); - return; - } - - bool is_input_unsigned{false}; - auto input_scale = - GetScaleValueForNode(nearest_interp_in, &is_input_unsigned); - QuantizeInput(g, - nearest_interp_op, - nearest_interp_in, - "X", - input_scale, - is_input_unsigned); - - bool is_output_unsigned{false}; - auto output_scale = - GetScaleValueForNode(nearest_interp_out, &is_output_unsigned); - DequantizeOutput(g, - nearest_interp_op, - nearest_interp_out, - "Out", - output_scale, - is_output_unsigned); - - ++quantize_nearest_interp_count; - }; - - gpd(graph, handler); - AddStatis(quantize_nearest_interp_count); - LogQuantizedOpsCounter("nearest_interp", quantize_nearest_interp_count); -} - void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Quantizing the graph."; PADDLE_ENFORCE_NOT_NULL( @@ -1293,18 +1131,19 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { QuantizePool(graph); QuantizeConcat(graph); QuantizePriorBox(graph); - QuantizeTranspose(graph); QuantizeFc(graph); - QuantizeReshape(graph); QuantizeMatmul(graph); + QuantizeImmutable(graph, "reshape2", "X"); + QuantizeImmutable(graph, "transpose2", "X"); + QuantizeImmutable(graph, "slice", "Input"); + QuantizeImmutable(graph, "nearest_interp", "X"); + QuantizeImmutable(graph, "nearest_interp_v2", "X"); QuantizeElementwise(graph, "elementwise_add"); QuantizeElementwise(graph, "elementwise_mul"); QuantizeElementwise(graph, "elementwise_sub"); QuantizeFusionGru(graph); QuantizeMultiGru(graph); QuantizeFusionLSTM(graph); - QuantizeSlice(graph); - QuantizeNearestInterp(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h index a880907402..56909b7fe7 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h @@ -54,16 +54,15 @@ class CPUQuantizePass : public FusePassBase { void QuantizePool(Graph* graph) const; void QuantizeConcat(Graph* graph) const; void QuantizePriorBox(Graph* graph) const; - void QuantizeTranspose(Graph* graph) const; - void QuantizeReshape(Graph* graph) const; void QuantizeMatmul(Graph* graph) const; void QuantizeElementwise(Graph* graph, - const std::string elementwise_type) const; + const std::string& elementwise_type) const; void QuantizeFusionGru(Graph* graph) const; void QuantizeMultiGru(Graph* graph) const; void QuantizeFusionLSTM(Graph* graph) const; - void QuantizeSlice(Graph* graph) const; - void QuantizeNearestInterp(Graph* graph) const; + void QuantizeImmutable(Graph* graph, + const std::string& immutable_type, + const std::string& input_name) const; void QuantizeInput(Graph* g, Node* op, diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc index 4fa79f6a87..322aa22c6a 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc @@ -550,55 +550,29 @@ void TestImmutableOpWithManyOutputs(const std::string tested_op) { SCALE * S8_MAX); } -TEST(CpuQuantizePass, reshape2) { TestImmutableOp("reshape2"); } +const std::vector immutables = { + "reshape2", "transpose2", "slice", "nearest_interp", "nearest_interp_v2"}; -TEST(CpuQuantizePass, reshape2BetweenNonQuantizedOp) { - TestImmutableOpBetweenNonQuantizedOp("reshape2"); -} - -TEST(CpuQuantizePass, reshape2WithManyOutputs) { - TestImmutableOpWithManyOutputs("reshape2"); -} - -TEST(CpuQuantizePass, transpose2) { TestImmutableOp("transpose2"); } - -TEST(CpuQuantizePass, transpose2BetweenNonQuantizedOp) { - TestImmutableOpBetweenNonQuantizedOp("transpose2"); -} - -TEST(CpuQuantizePass, transpose2WithManyOutputs) { - TestImmutableOpWithManyOutputs("transpose2"); -} - -TEST(CpuQuantizePass, slice) { TestImmutableOp("slice"); } - -TEST(CpuQuantizePass, sliceBetweenNonQuantizedOp) { - TestImmutableOpBetweenNonQuantizedOp("slice"); -} - -TEST(CpuQuantizePass, sliceWithManyOutputs) { - TestImmutableOpWithManyOutputs("slice"); -} +class TestImmutables : public testing::TestWithParam {}; -TEST(CpuQuantizePass, nearestInterp) { TestImmutableOp("nearest_interp"); } - -TEST(CpuQuantizePass, nearestInterpBetweenNonQuantizedOp) { - TestImmutableOpBetweenNonQuantizedOp("nearest_interp"); -} +TEST_P(TestImmutables, immutable_basic) { TestImmutableOp(GetParam()); } -TEST(CpuQuantizePass, nearestInterpWithManyOutputs) { - TestImmutableOpWithManyOutputs("nearest_interp"); +TEST_P(TestImmutables, immutable_between_non_quantized) { + TestImmutableOpBetweenNonQuantizedOp(GetParam()); } -TEST(CpuQuantizePass, nearestInterpV2) { TestImmutableOp("nearest_interp_v2"); } - -TEST(CpuQuantizePass, nearestInterpV2BetweenNonQuantizedOp) { - TestImmutableOpBetweenNonQuantizedOp("nearest_interp_v2"); +TEST_P(TestImmutables, immutable_many_outputs) { + TestImmutableOpWithManyOutputs(GetParam()); } -TEST(CpuQuantizePass, nearestInterpV2WithManyOutputs) { - TestImmutableOpWithManyOutputs("nearest_interp_v2"); -} +INSTANTIATE_TEST_CASE_P( + CpuQuantizePass, + TestImmutables, + testing::ValuesIn(immutables), + [](const ::testing::TestParamInfo& info) { + std::string name = info.param; + return name; + }); static const std::initializer_list variable_names_matmul = { "a", "b", "c", "d", "e", "f"}; @@ -735,7 +709,7 @@ TEST_P(TestElementwises, elementwise_unsigned_and_signed_input) { } INSTANTIATE_TEST_CASE_P( - Elementwises, + CpuQuantizePass, TestElementwises, testing::ValuesIn(elementwises), [](const ::testing::TestParamInfo& info) { -- GitLab