From 130db92ac28fa5a325d0366c3ad60c43021bf205 Mon Sep 17 00:00:00 2001
From: Paulina Gacek
Date: Tue, 8 Nov 2022 04:11:15 +0100
Subject: [PATCH] Split quant (#47449)

* Split kernel registered, tests for uint/int added
* Split quantized
* Split output scales calculated only once
* NearestInterp test fix reversed
* DequantizeOutputs corrected
---
 .../framework/ir/graph_pattern_detector.h      |  2 +-
 .../compute_propagate_scales_mkldnn_pass.cc    |  3 +-
 .../framework/ir/mkldnn/cpu_quantize_pass.cc   | 67 +++++++++++++--
 .../framework/ir/mkldnn/cpu_quantize_pass.h    |  8 ++
 .../ir/mkldnn/cpu_quantize_pass_tester.cc      | 11 ++-
 .../ir/mkldnn/cpu_quantize_placement_pass.cc   |  3 +-
 paddle/fluid/inference/api/mkldnn_quantizer.cc |  2 +-
 .../inference/api/mkldnn_quantizer_config.cc   |  3 +
 .../inference/api/paddle_analysis_config.h     |  3 +-
 paddle/fluid/operators/split_op.cc             |  2 +-
 paddle/phi/kernels/onednn/split_kernel.cc      | 14 ++-
 .../quantization/quant2_int8_mkldnn_pass.py    |  2 +
 .../unittests/mkldnn/test_split_mkldnn_op.py   | 85 +++++++++++++++----
 13 files changed, 171 insertions(+), 34 deletions(-)
 mode change 100755 => 100644 paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc

diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index a4ee0a0983..0ec4e0c276 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1112,7 +1112,7 @@ struct ResidualElementwise : public PatternBase {
 };
 
 // General struct for immutable ops:
-// reshape, transpose, slice, shape, nearest-interp
+// reshape, transpose, slice, shape, nearest-interp, split
 // Forward pass for no weights-op.
 // immutable_out is a result of the operator.
 struct Immutable : public PatternBase {
diff --git a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc
index 7f1bc37183..f1686e445f 100644
--- a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc
@@ -498,7 +498,8 @@ void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const {
       "slice",
       "shape",
       "nearest_interp",
-      "nearest_interp_v2"};
+      "nearest_interp_v2",
+      "split"};
 
   StringPairMap var_quant_scales{};
 
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
index 80aa81b9b0..efdba4f44f 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
@@ -245,6 +245,54 @@ void CPUQuantizePass::DequantizeOutput(Graph* g,
   if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
 }
 
+void CPUQuantizePass::DequantizeOutputs(Graph* g,
+                                        Node* op,
+                                        std::string output_name,
+                                        double scale_to_one,
+                                        bool is_unsigned,
+                                        std::string scale_attr_name) const {
+  auto outputs = op->outputs;
+  PADDLE_ENFORCE_GE(outputs.size(),
+                    1,
+                    platform::errors::InvalidArgument(
+                        "OP(%s)'s outputs(%d) must be greater than or equal to 1.",
+                        op->Name(),
+                        outputs.size()));
+
+  std::vector<std::string> quantize_in_node_names(outputs.size());
+
+  unsigned max = is_unsigned ? U8_MAX : S8_MAX;
+  float scale = scale_to_one * max;
+
+  for (size_t i = 0; i < outputs.size(); i++) {
+    // Create a dequantize input variable.
+    VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
+    Node* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
+    quantize_in_node_names[i] = dequantize_in_node->Name();
+
+    // Create a dequantize op node for this output.
+    OpDesc deq_desc;
+    deq_desc.SetType("dequantize");
+    deq_desc.SetInput("Input",
+                      std::vector<std::string>({quantize_in_node_names[i]}));
+    deq_desc.SetOutput("Output",
+                       std::vector<std::string>({outputs[i]->Name()}));
+    deq_desc.SetAttr("Scale", scale);
+    deq_desc.SetAttr("is_negative_input", !is_unsigned);
+    auto dequantize_op = g->CreateOpNode(&deq_desc);  // OpDesc will be copied.
+
+    // Link the dequantize op.
+    UnlinkNodes(op, outputs[i]);
+    IR_NODE_LINK_TO(op, dequantize_in_node);
+    IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
+    IR_NODE_LINK_TO(dequantize_op, outputs[i]);
+  }
+
+  // Update the op's output.
+  op->Op()->SetOutput(output_name, quantize_in_node_names);
+  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
+}
+
 bool CPUQuantizePass::AreScalesPresentForVarNames(
     std::vector<std::string> names) const {
   bool present = true;
@@ -730,13 +778,17 @@ void CPUQuantizePass::QuantizeImmutable(Graph* graph,
     bool is_output_unsigned{false};
     auto output_scale =
         GetScaleValueForNode(immutable_out, &is_output_unsigned);
-    DequantizeOutput(g,
-                     immutable_op,
-                     immutable_out,
-                     "Out",
-                     output_scale,
-                     is_output_unsigned);
-
+    if (immutable_type == "split") {  // ops with multiple outputs
+      DequantizeOutputs(
+          g, immutable_op, "Out", output_scale, is_output_unsigned);
+    } else {
+      DequantizeOutput(g,
+                       immutable_op,
+                       immutable_out,
+                       "Out",
+                       output_scale,
+                       is_output_unsigned);
+    }
     ++quantize_immutable_count;
   };
 
@@ -1184,6 +1236,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
   QuantizeImmutable(graph, "slice", "Input");
   QuantizeImmutable(graph, "nearest_interp", "X");
   QuantizeImmutable(graph, "nearest_interp_v2", "X");
+  QuantizeImmutable(graph, "split", "X");
   QuantizeElementwise(graph, "elementwise_add");
   QuantizeElementwise(graph, "elementwise_mul");
   QuantizeElementwise(graph, "elementwise_sub");
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h
index 522b57eb13..64f9b11ee9 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h
@@ -91,6 +91,14 @@ class CPUQuantizePass : public FusePassBase {
                         bool is_unsigned,
                         std::string scale_attr_name = "") const;
 
+  // Dequantize all outputs of the given name.
+  void DequantizeOutputs(Graph* g,
+                         Node* op,
+                         std::string output_name,
+                         double scale_to_one,
+                         bool is_unsigned,
+                         std::string scale_attr_name = "") const;
+
   bool AreScalesPresentForVarNames(std::vector<std::string> names) const;
   bool AreScalesPresentForNodes(std::initializer_list<Node*> nodes) const;
   std::pair<bool, phi::DenseTensor> GetScaleDataByName(
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
old mode 100755
new mode 100644
index 201c2be160..1be8a6ca44
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
@@ -69,6 +69,9 @@ void SetOp(ProgramDesc* prog,
   } else if (type == "slice") {
     op->SetInput("Input", {inputs[0]});
     op->SetOutput("Out", {outputs[0]});
+  } else if (type == "split") {
+    op->SetInput("X", {inputs[0]});
+    op->SetOutput("Out", {outputs});
else if (type == "dropout") { op->SetInput("X", {inputs[0]}); op->SetOutput("Out", {outputs[0]}); @@ -556,8 +559,12 @@ void TestImmutableOpWithManyOutputs(const std::string tested_op) { SCALE * S8_MAX); } -const std::vector immutables = { - "reshape2", "transpose2", "slice", "nearest_interp", "nearest_interp_v2"}; +const std::vector immutables = {"reshape2", + "transpose2", + "slice", + "nearest_interp", + "nearest_interp_v2", + "split"}; class TestImmutables : public testing::TestWithParam {}; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc index 7e0388cb80..70433772ce 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc @@ -42,7 +42,8 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const { "fusion_gru", "fusion_lstm", "multi_gru", - "slice"}); + "slice", + "split"}); const auto& excluded_ids_list = Get>("quantize_excluded_op_ids"); const auto& op_types_list = diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc index 017933240a..27bbdc0bbf 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc @@ -131,7 +131,7 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs( is_unsigned = true; } else if (op->Type() == "transpose2" || op->Type() == "reshape2" || op->Type() == "pool2d" || op->Type() == "nearest_interp" || - op->Type() == "nearest_interp_v2") { + op->Type() == "nearest_interp_v2" || op->Type() == "split") { auto input_var_name = op->Input("X")[0]; PADDLE_ENFORCE_NE(scales_.find(input_var_name), scales_.end(), diff --git a/paddle/fluid/inference/api/mkldnn_quantizer_config.cc b/paddle/fluid/inference/api/mkldnn_quantizer_config.cc index bfe6c5a947..0beac10903 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer_config.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer_config.cc @@ -48,6 +48,9 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() { rules_["shape"]["Input"] = ScaleAlgo::KL; rules_["shape"]["Out"] = ScaleAlgo::NONE; + rules_["split"]["X"] = ScaleAlgo::KL; + rules_["split"]["Out"] = ScaleAlgo::NONE; + rules_["fc"]["Input"] = ScaleAlgo::KL; rules_["fc"]["W"] = ScaleAlgo::MAX_CH_T; rules_["fc"]["Bias"] = ScaleAlgo::NONE; diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index f9587c0995..99e8ddac04 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -1134,7 +1134,8 @@ struct PD_INFER_DECL AnalysisConfig { "fusion_gru", "fusion_lstm", "multi_gru", - "slice"}; + "slice", + "split"}; // ipu related. bool use_ipu_{false}; diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index 9efef84b1d..68eab9365a 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -198,7 +198,7 @@ Example: "mkldnn_data_type", "(string, default \"float32\"). 
Data type of mkldnn kernel") .SetDefault("float32") - .InEnum({"float32", "bfloat16"}); + .InEnum({"float32", "bfloat16", "int8", "uint8"}); } }; diff --git a/paddle/phi/kernels/onednn/split_kernel.cc b/paddle/phi/kernels/onednn/split_kernel.cc index 057aa7325f..062233b198 100644 --- a/paddle/phi/kernels/onednn/split_kernel.cc +++ b/paddle/phi/kernels/onednn/split_kernel.cc @@ -77,12 +77,20 @@ void SplitWithNumKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL( - split, OneDNN, ONEDNN, phi::SplitKernel, float, phi::dtype::bfloat16) {} +PD_REGISTER_KERNEL(split, + OneDNN, + ONEDNN, + phi::SplitKernel, + float, + phi::dtype::bfloat16, + int8_t, + uint8_t) {} PD_REGISTER_KERNEL(split_with_num, OneDNN, ONEDNN, phi::SplitWithNumKernel, float, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + int8_t, + uint8_t) {} diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index c7723097f4..fcc2daff20 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -74,6 +74,7 @@ class Quant2Int8MkldnnPass(object): 'shape', 'nearest_interp', 'nearest_interp_v2', + 'split', ] self._scale_ops = ['scale'] self._conv_ops = ['conv2d', 'depthwise_conv2d'] @@ -284,6 +285,7 @@ class Quant2Int8MkldnnPass(object): self._var_quant_scales[ input_name ] = self._var_quant_scales[output_name] + elif op.name() == 'concat': output_name = op.output("Out")[0] if output_name in self._var_quant_scales: diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py index e148e0cdca..79ab0e2f01 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py @@ -19,11 +19,27 @@ from paddle.fluid.tests.unittests.op_test import OpTest class TestSplitSectionsOneDNNOp(OpTest): - def init_data(self): - self.x = np.random.random((4, 5, 6)).astype("float32") + def init_data_type(self): + self.dtype = np.float32 + + def init_x(self): + if self.dtype == np.float32: + self.x = np.random.random(self.input_shape).astype(self.dtype) + elif self.dtype == np.int8: + self.x = np.random.randint(-5, 5, self.input_shape).astype( + self.dtype + ) + else: # uint8 + self.x = np.random.randint(0, 10, self.input_shape).astype( + self.dtype + ) + + def init_test_case(self): + self.input_shape = (4, 5, 6) + self.init_x() self.axis = 1 + self.num = 0 self.sections = [2, 1, 2] - indices_or_sections = [2, 3] # sections np_sections = [2, 3] self.out = np.split(self.x, np_sections, self.axis) @@ -31,8 +47,8 @@ class TestSplitSectionsOneDNNOp(OpTest): self.op_type = "split" self.axis_tensor = None self.sections_tensor_list = None - self.num = 0 - self.init_data() + self.init_data_type() + self.init_test_case() self.inputs = {'X': self.x} self.attrs = {'use_mkldnn': True, 'num': self.num} @@ -58,11 +74,12 @@ class TestSplitSectionsOneDNNOp(OpTest): # test with attr(num) class TestSplitNumOneDNNOp(TestSplitSectionsOneDNNOp): - def init_data(self): - self.x = np.random.random((4, 8, 5, 3)).astype("float32") + def init_test_case(self): + self.input_shape = (4, 8, 5, 3) + self.init_x() self.axis = 1 - self.sections = [] self.num = 4 + self.sections = [] indices_or_sections = 4 # indices self.out = np.split(self.x, indices_or_sections, self.axis) @@ -71,20 +88,23 
@@ -71,20 +88,23 @@ class TestSplitNumOneDNNOp(TestSplitSectionsOneDNNOp):
 
 
 class TestSplitNumAxisTensorOneDNNOp(TestSplitSectionsOneDNNOp):
-    def init_data(self):
-        self.x = np.random.random((4, 5, 6)).astype("float32")
+    def init_test_case(self):
+        self.input_shape = (4, 5, 6)
+        self.init_x()
+        self.num = 3
         self.axis = None
         self.sections = []
-        self.num = 3
-        indices_or_sections = 3  # indices
         self.axis_tensor = np.array([2]).astype("int32")
+        indices_or_sections = 3  # indices
         self.out = np.split(self.x, indices_or_sections, 2)
 
 
 # attr(sections) is list containing Tensor
 class TestSplitSectionsTensorOneDNNOp(TestSplitSectionsOneDNNOp):
-    def init_data(self):
-        self.x = np.random.random((4, 5, 6)).astype("float32")
+    def init_test_case(self):
+        self.input_shape = (4, 5, 6)
+        self.init_x()
+        self.num = 0
         self.axis = 1
         self.sections = [2, 1, 2]
         self.sections_tensor_list = []
@@ -98,14 +118,47 @@ class TestSplitSectionsTensorOneDNNOp(TestSplitSectionsOneDNNOp):
 
 
 class TestSplitOpUnknownSectionOneDNNOp(TestSplitSectionsOneDNNOp):
-    def init_data(self):
-        self.x = np.random.random((4, 5, 6)).astype("float32")
+    def init_test_case(self):
+        self.input_shape = (4, 5, 6)
+        self.init_x()
+        self.num = 0
         self.axis = 2
         self.sections = [2, 2, -1]
         indices_or_sections = [2, 4]  # sections
         self.out = np.split(self.x, indices_or_sections, self.axis)
 
 
+def create_test_class(parent):
+    '''
+    Create int8 and uint8 versions for each test.
+    Parent tests work by default on fp32.
+    '''
+
+    class TestInt8Case(parent):
+        def init_data_type(self):
+            self.dtype = np.int8
+
+        def test_check_grad(self):
+            pass
+
+    class TestUint8Case(parent):
+        def init_data_type(self):
+            self.dtype = np.uint8
+
+        def test_check_grad(self):
+            pass
+
+    TestInt8Case.__name__ = "{0}_{1}".format(parent.__name__, "INT8")
+    TestUint8Case.__name__ = "{0}_{1}".format(parent.__name__, "UINT8")
+    globals()[TestInt8Case.__name__] = TestInt8Case
+    globals()[TestUint8Case.__name__] = TestUint8Case
+
+
+create_test_class(TestSplitNumOneDNNOp)
+create_test_class(TestSplitNumAxisTensorOneDNNOp)
+create_test_class(TestSplitSectionsTensorOneDNNOp)
+create_test_class(TestSplitOpUnknownSectionOneDNNOp)
+create_test_class(TestSplitSectionsOneDNNOp)
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
-- 
GitLab
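
Note on the scale handling in DequantizeOutputs above: because all outputs of split share the input's quantization parameters, the pass computes one scale (scale_to_one * U8_MAX or S8_MAX) and attaches a separate dequantize op to each output, all carrying that same Scale attribute. The following minimal NumPy sketch, not part of the patch, illustrates that semantics; the function name is purely illustrative, and it assumes the oneDNN dequantize op divides its input by the Scale attribute (the inverse of the multiply done at quantization).

import numpy as np

U8_MAX, S8_MAX = 255, 127


def split_then_dequantize(x_quant, sections, axis, scale_to_one, is_unsigned):
    # One shared scale, computed once -- mirrors `scale = scale_to_one * max`.
    scale = scale_to_one * (U8_MAX if is_unsigned else S8_MAX)
    # Split runs on the quantized integers; each output then gets its own
    # dequantize step that divides by the shared scale.
    return [part.astype(np.float32) / scale
            for part in np.split(x_quant, sections, axis)]


# int8 input split into chunks of width 2, 1 and 2 along axis 1, matching the
# sections used by TestSplitSectionsOneDNNOp.
x = np.random.randint(-5, 5, (4, 5, 6)).astype(np.int8)
outs = split_then_dequantize(x, [2, 3], 1, scale_to_one=1.0 / 127, is_unsigned=False)
print([o.shape for o in outs])  # [(4, 2, 6), (4, 1, 6), (4, 2, 6)]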