From bd68761a0ed239bd6d0af7656f0956783d33e129 Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Fri, 25 Jun 2021 16:27:36 +0800
Subject: [PATCH] [ pass_enhance ] quant_conv2d_dequant_fuse_pass (#33737)

---
 .../ir/quant_conv2d_dequant_fuse_pass.cc      | 218 +++++++++++++++++-
 .../ir/quant_conv2d_dequant_fuse_pass.h       |  11 +-
 paddle/fluid/operators/compat/conv2d.pbtxt    |  16 ++
 ...fake_channel_wise_dequantize_max_abs.pbtxt |  47 ++++
 .../slim/quantization/quantization_pass.py    |   6 +-
 5 files changed, 287 insertions(+), 11 deletions(-)
 create mode 100644 paddle/fluid/operators/compat/fake_channel_wise_dequantize_max_abs.pbtxt

diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
index 2fc39fd25d5..a092c894d9e 100644
--- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
@@ -21,11 +21,209 @@
 namespace paddle {
 namespace framework {
 namespace ir {
-
+QuantDequantFusePass::QuantDequantFusePass() {
+  AddOpCompat(OpCompat("fake_quantize_range_abs_max"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("InScale")
+      .IsTensor()
+      .End()
+      .AddInput("Iter")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("OutScale")
+      .IsTensor()
+      .End()
+      .AddOutput("OutScales")
+      .IsTensor()
+      .End()
+      .AddAttr("window_size")
+      .IsType<int>()
+      .IsNumGT(0)
+      .End()
+      .AddAttr("bit_length")
+      .IsIntIn({8, 16})
+      .End();
+  AddOpCompat(OpCompat("fake_quantize_moving_average_abs_max"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("InScale")
+      .IsTensor()
+      .End()
+      .AddInput("InAccum")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("InState")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("OutScale")
+      .IsTensor()
+      .End()
+      .AddOutput("OutState")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("OutAccum")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddAttr("moving_rate")
+      .IsType<float>()
+      .IsNumGT(0.0f)
+      .End()
+      .AddAttr("bit_length")
+      .IsIntIn({8, 16})
+      .End();
+  AddOpCompat(OpCompat("fake_dequantize_max_abs"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Scale")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("max_range")
+      .IsType<float>()
+      .IsNumGT(0.0f)
+      .End();
+  AddOpCompat(OpCompat("fake_channel_wise_dequantize_max_abs"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Scales")  // "Scales" is a vector with at most two tensors
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("quant_bits")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("quant_axis")
+      .IsIntIn({0, 1})
+      .IsOptional()
+      .End();
+  AddOpCompat(OpCompat("conv2d"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("paddings")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsOptional()
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
+      .End();
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("y_num_col_dims")
+      .IsNumEQ(1)
+      .End();
+  AddOpCompat(OpCompat("fc"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("W")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("in_num_col_dims")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("activation_type")
+      .IsStringIn({"relu", ""})
+      .End();
+  AddOpCompat(OpCompat("conv2d_transpose"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("paddings")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .IsOptional()
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
+      .End();
+}
 // Delete quant op before quantized ops, and set input scale in the attr of
 // quantized ops
-void DeleteQuant(ir::Graph* graph, Scope* scope,
-                 const std::string& quant_type) {
+void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
+                                       const std::string& quant_type) const {
   const std::string pattern_name = "delete_quant_fuse";
   GraphPatternDetector gpd;
   auto* input_act_node = gpd.mutable_pattern()
@@ -41,6 +239,10 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
   // ops linked from it
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
     PADDLE_ENFORCE_EQ(
         subgraph.count(input_act_node), true,
         platform::errors::NotFound(
@@ -103,9 +305,9 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
 
 // Delete dequant op after quantized ops, and convert weight from fp32 range to
 // int8 range
-void FuseDequant(ir::Graph* graph, Scope* scope,
-                 const std::string& quantized_op_type,
-                 const std::string& dequant_type) {
+void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
+                                       const std::string& quantized_op_type,
+                                       const std::string& dequant_type) const {
   std::string weight_name = "";
   std::string input_name = "";
   if (quantized_op_type == "conv2d" ||
@@ -142,6 +344,10 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
   // Create new op desc
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
     PADDLE_ENFORCE_EQ(
         subgraph.count(quantized_op_input), true,
         platform::errors::NotFound("Quantized op input node(%s) did not find "
diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h
index a16dc7620b4..521e186c2be 100644
--- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h
+++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h
@@ -16,7 +16,6 @@
 #include <string>
 
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 
 namespace paddle {
 namespace framework {
@@ -25,14 +24,20 @@ namespace ir {
 ///
 /// Fuse quant + conv2d/depthwise_conv2d/mul/fc + dequant
 ///
-class Graph;
-
 class QuantDequantFusePass : public FusePassBase {
  public:
+  QuantDequantFusePass();
   virtual ~QuantDequantFusePass() {}
 
  protected:
   void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  void DeleteQuant(ir::Graph* graph, Scope* scope,
+                   const std::string& quant_type) const;
+  void FuseDequant(ir::Graph* graph, Scope* scope,
+                   const std::string& quantized_op_type,
+                   const std::string& dequant_type) const;
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/operators/compat/conv2d.pbtxt b/paddle/fluid/operators/compat/conv2d.pbtxt
index d8a08b6b410..9e4c8b796a8 100644
--- a/paddle/fluid/operators/compat/conv2d.pbtxt
+++ b/paddle/fluid/operators/compat/conv2d.pbtxt
@@ -41,6 +41,22 @@ def {
   }
 }
 extra {
+  attrs {
+    name: "Input_scale"
+    type: FLOAT
+  }
+  attrs {
+    name: "quantization_type"
+    type: STRING
+  }
+  attrs {
+    name: "bit_length"
+    type: INT
+  }
+  attrs {
+    name: "out_threshold"
+    type: FLOAT
+  }
   attrs {
     name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
     type: BOOLEAN
diff --git a/paddle/fluid/operators/compat/fake_channel_wise_dequantize_max_abs.pbtxt b/paddle/fluid/operators/compat/fake_channel_wise_dequantize_max_abs.pbtxt
new file mode 100644
index 00000000000..542a0ff649f
--- /dev/null
+++ b/paddle/fluid/operators/compat/fake_channel_wise_dequantize_max_abs.pbtxt
@@ -0,0 +1,47 @@
+type: "fake_channel_wise_dequantize_max_abs"
+def {
+  inputs {
+    name: "X"
+  }
+  inputs {
+    name: "Scales"
+  }
+  outputs {
+    name: "Out"
+  }
+  attrs {
+    name: "quant_bits"
+    type: INTS
+  }
+  attrs {
+    name: "quant_axis"
+    type: INT
+  }
+}
+extra {
+  attrs {
+    name: "is_test"
+    type: BOOLEAN
+  }
+  attrs {
+    name: "op_role"
+    type: INT
+  }
+  attrs {
+    name: "op_role_var"
+    type: STRINGS
+  }
+  attrs {
+    name: "op_namescope"
+    type: STRING
+  }
+  attrs {
+    name: "op_callstack"
+    type: STRINGS
+  }
+  attrs {
+    name: "op_device"
+    type: STRING
+  }
+}
+
diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index fb69e29f340..010c6a67a3a 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -1183,7 +1183,8 @@ class QuantizationFreezePass(object):
             if op_node_desc.has_attr("quantization_type") and \
                 op_node_desc.attr("quantization_type") == "qat_with_weight":
                 if self._weight_quantize_type == 'channel_wise_abs_max':
-                    self._insert_post_channel_dequant_op(graph, op_node)
+                    self._insert_post_channel_dequant_op(graph, op_node,
+                                                         quant_axis)
                 else:
                     self._insert_post_dequant_op(graph, op_node)
 
@@ -1210,7 +1211,7 @@ class QuantizationFreezePass(object):
                         v.node]
                 graph.safe_remove_nodes(op_node)
 
-    def _insert_post_channel_dequant_op(self, graph, op_node):
+    def _insert_post_channel_dequant_op(self, graph, op_node, quant_axis):
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
         for var_node in op_node.inputs:
             name = var_node.name()
@@ -1258,6 +1259,7 @@ class QuantizationFreezePass(object):
             op_type='fake_channel_wise_dequantize_max_abs',
             attrs={
                 'quant_bits': [self._weight_bits, self._activation_bits],
+                'quant_axis': quant_axis,
                 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
             },
             inputs={
--
GitLab
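
Usage note (beyond the diff): the Python hunks above thread a new `quant_axis`
attribute through QuantizationFreezePass when weights are quantized
channel-wise. Below is a minimal sketch of how that pass is typically applied,
assuming a QAT-trained inference program; `test_program` is a hypothetical
placeholder for such a program, and the training code that produces it is
omitted:

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph
    from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass

    place = fluid.CPUPlace()
    scope = fluid.global_scope()

    # Wrap the program desc in an IrGraph so IR passes can operate on it.
    # `test_program` is assumed to be a QAT-trained inference Program.
    graph = IrGraph(core.Graph(test_program.desc), for_test=True)

    # With 'channel_wise_abs_max', the pass routes through
    # _insert_post_channel_dequant_op, which (after this patch) also sets
    # the 'quant_axis' attribute on the inserted
    # fake_channel_wise_dequantize_max_abs op.
    freeze_pass = QuantizationFreezePass(
        scope=scope,
        place=place,
        weight_bits=8,
        activation_bits=8,
        weight_quantize_type='channel_wise_abs_max')
    freeze_pass.apply(graph)

    frozen_program = graph.to_program()

The frozen program can then be saved for inference, where the op-compat
definitions added in this patch let quant_conv2d_dequant_fuse_pass validate
the quant/dequant ops before fusing them.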