From f9420e8344174ab8280206dc6513a9e7dde95a6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Wed, 23 Jun 2021 14:16:06 +0800 Subject: [PATCH] add compat precondition for delete_quant_dequant_filter_op_pass, test=develop (#33705) --- .../ir/delete_quant_dequant_filter_op_pass.cc | 36 +++++++++++++++ .../ir/delete_quant_dequant_filter_op_pass.h | 4 +- ...nel_wise_quantize_dequantize_abs_max.pbtxt | 46 +++++++++++++++++++ .../fake_quantize_dequantize_abs_max.pbtxt | 38 +++++++++++++++ 4 files changed, 121 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/operators/compat/fake_channel_wise_quantize_dequantize_abs_max.pbtxt create mode 100644 paddle/fluid/operators/compat/fake_quantize_dequantize_abs_max.pbtxt diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc index 4379bba638..4ce9199920 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc @@ -32,6 +32,37 @@ namespace ir { GET_IR_NODE(quant_dequant_op_outscale); \ GET_IR_NODE(any_op2); +DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() { + AddOpCompat(OpCompat("fake_quantize_dequantize_abs_max")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("OutScale") + .IsTensor() + .End() + .AddAttr("bit_length") + .IsIntIn({8, 16}) + .End(); + AddOpCompat(OpCompat("fake_channel_wise_quantize_dequantize_abs_max")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("OutScale") + .IsTensor() + .End() + .AddAttr("bit_length") + .IsIntIn({8, 16}) + .End() + .AddAttr("quant_axis") + .IsIntIn({0, 1}) + .End(); +} // Delete quant_dequant_op, then quantize and dequantize weight void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { const 
std::string pattern_name = "delete_quantdequant_filter_op_pattern"; @@ -50,6 +81,11 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { Graph* g) { GET_NODES; + if (!IsCompat(*quant_dequant_op->Op())) { + LOG(WARNING) << "quant_dequant_op in delete_quant_dequant_filter_op_pass " + "compat check failed."; + return; + } std::unordered_set<const Node*> nodes2rm = {}; int bit_length = BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length")); diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h index 0409032d93..23049aac96 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h +++ b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h @@ -16,16 +16,14 @@ #include <vector> #include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" namespace paddle { namespace framework { namespace ir { -class Graph; - class DeleteQuantDequantFilterOpPass : public FusePassBase { public: + DeleteQuantDequantFilterOpPass(); virtual ~DeleteQuantDequantFilterOpPass() {} protected: diff --git a/paddle/fluid/operators/compat/fake_channel_wise_quantize_dequantize_abs_max.pbtxt b/paddle/fluid/operators/compat/fake_channel_wise_quantize_dequantize_abs_max.pbtxt new file mode 100644 index 0000000000..7c49da93e7 --- /dev/null +++ b/paddle/fluid/operators/compat/fake_channel_wise_quantize_dequantize_abs_max.pbtxt @@ -0,0 +1,46 @@ +type: "fake_channel_wise_quantize_dequantize_abs_max" +def { + inputs { + name: "X" + } + outputs { + name: "Out" + } + outputs { + name: "OutScale" + } + attrs { + name: "quant_axis" + type: INT + } + attrs { + name: "bit_length" + type: INT + } +} +extra { + attrs { + name: "is_test" + type: BOOLEAN + } + attrs { + name: "op_role" + type: INT + } + attrs { + name: "op_role_var" + type: STRINGS + } + attrs { + name: "op_namescope" + type: STRING + } + attrs { + name: "op_callstack" 
+ type: STRINGS + } + attrs { + name: "op_device" + type: STRING + } +} diff --git a/paddle/fluid/operators/compat/fake_quantize_dequantize_abs_max.pbtxt b/paddle/fluid/operators/compat/fake_quantize_dequantize_abs_max.pbtxt new file mode 100644 index 0000000000..bebb397e20 --- /dev/null +++ b/paddle/fluid/operators/compat/fake_quantize_dequantize_abs_max.pbtxt @@ -0,0 +1,38 @@ +type: "fake_quantize_dequantize_abs_max" +def { + inputs { + name: "X" + } + outputs { + name: "Out" + } + outputs { + name: "OutScale" + } + attrs { + name: "bit_length" + type: INT + } +} +extra { + attrs { + name: "op_role" + type: INT + } + attrs { + name: "op_role_var" + type: STRINGS + } + attrs { + name: "op_namescope" + type: STRING + } + attrs { + name: "op_callstack" + type: STRINGS + } + attrs { + name: "op_device" + type: STRING + } +} -- GitLab