Summary of this patch (commentary only; text before the first "diff --git"
line is not part of any hunk):

  * delete_quant_dequant_filter_op_pass.cc
      - Adds a DeleteQuantDequantFilterOpPass constructor that registers
        op-compat definitions for "fake_quantize_dequantize_abs_max" and
        "fake_channel_wise_quantize_dequantize_abs_max": input X, outputs
        Out and OutScale, attribute bit_length restricted to {8, 16}, and,
        for the channel-wise op only, quant_axis restricted to {0, 1}.
      - In the match handler inside ApplyImpl, logs a warning and returns
        early when IsCompat() rejects the matched quant_dequant_op.
  * delete_quant_dequant_filter_op_pass.h
      - Declares the new constructor; removes the graph_pattern_detector.h
        include and the forward declaration of Graph.
  * Adds compat .pbtxt definitions for the two ops listed above.

NOTE(review): the line structure, indentation and blank context lines below
were reconstructed so that they agree with the hunk header counts; exact
column alignment (e.g. of the macro continuation backslash) could not be
recovered. The tokens "<vector>" and "<const Node*>" were absent from the
text this was recovered from and are restored on the assumption that
angle-bracket content was stripped -- confirm against the repository.

diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
index 4379bba6380c598431cce76717742dc96af3a142..4ce91999207a2b1a8ad2a3ab594aa74f9aece8e3 100644
--- a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
+++ b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
@@ -32,6 +32,37 @@ namespace ir {
   GET_IR_NODE(quant_dequant_op_outscale); \
   GET_IR_NODE(any_op2);
 
+DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
+  AddOpCompat(OpCompat("fake_quantize_dequantize_abs_max"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("OutScale")
+      .IsTensor()
+      .End()
+      .AddAttr("bit_length")
+      .IsIntIn({8, 16})
+      .End();
+  AddOpCompat(OpCompat("fake_channel_wise_quantize_dequantize_abs_max"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("OutScale")
+      .IsTensor()
+      .End()
+      .AddAttr("bit_length")
+      .IsIntIn({8, 16})
+      .End()
+      .AddAttr("quant_axis")
+      .IsIntIn({0, 1})
+      .End();
+}
 // Delete quant_dequant_op, then quantize and dequantize weight
 void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "delete_quantdequant_filter_op_pattern";
@@ -50,6 +81,11 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
                      Graph* g) {
     GET_NODES;
 
+    if (!IsCompat(*quant_dequant_op->Op())) {
+      LOG(WARNING) << "quant_dequant_op in delete_quant_dequant_filter_op_pass "
+                      "compat check failed.";
+      return;
+    }
     std::unordered_set<const Node*> nodes2rm = {};
     int bit_length =
         BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h
index 0409032d93816a2ba3121f2390aef5e59681ca9f..23049aac9622ee31609d8bf353f23a6f8ba3a6ff 100644
--- a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h
+++ b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h
@@ -16,16 +16,14 @@
 #include <vector>
 
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 
 namespace paddle {
 namespace framework {
 namespace ir {
 
-class Graph;
-
 class DeleteQuantDequantFilterOpPass : public FusePassBase {
  public:
+  DeleteQuantDequantFilterOpPass();
   virtual ~DeleteQuantDequantFilterOpPass() {}
 
  protected:
diff --git a/paddle/fluid/operators/compat/fake_channel_wise_quantize_dequantize_abs_max.pbtxt b/paddle/fluid/operators/compat/fake_channel_wise_quantize_dequantize_abs_max.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..7c49da93e71836032f2eb8f784def337d27b4d4d
--- /dev/null
+++ b/paddle/fluid/operators/compat/fake_channel_wise_quantize_dequantize_abs_max.pbtxt
@@ -0,0 +1,46 @@
+type: "fake_channel_wise_quantize_dequantize_abs_max"
+def {
+  inputs {
+    name: "X"
+  }
+  outputs {
+    name: "Out"
+  }
+  outputs {
+    name: "OutScale"
+  }
+  attrs {
+    name: "quant_axis"
+    type: INT
+  }
+  attrs {
+    name: "bit_length"
+    type: INT
+  }
+}
+extra {
+  attrs {
+    name: "is_test"
+    type: BOOLEAN
+  }
+  attrs {
+    name: "op_role"
+    type: INT
+  }
+  attrs {
+    name: "op_role_var"
+    type: STRINGS
+  }
+  attrs {
+    name: "op_namescope"
+    type: STRING
+  }
+  attrs {
+    name: "op_callstack"
+    type: STRINGS
+  }
+  attrs {
+    name: "op_device"
+    type: STRING
+  }
+}
diff --git a/paddle/fluid/operators/compat/fake_quantize_dequantize_abs_max.pbtxt b/paddle/fluid/operators/compat/fake_quantize_dequantize_abs_max.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..bebb397e20bbe7dd31e4b374621c55b49b48b38e
--- /dev/null
+++ b/paddle/fluid/operators/compat/fake_quantize_dequantize_abs_max.pbtxt
@@ -0,0 +1,38 @@
+type: "fake_quantize_dequantize_abs_max"
+def {
+  inputs {
+    name: "X"
+  }
+  outputs {
+    name: "Out"
+  }
+  outputs {
+    name: "OutScale"
+  }
+  attrs {
+    name: "bit_length"
+    type: INT
+  }
+}
+extra {
+  attrs {
+    name: "op_role"
+    type: INT
+  }
+  attrs {
+    name: "op_role_var"
+    type: STRINGS
+  }
+  attrs {
+    name: "op_namescope"
+    type: STRING
+  }
+  attrs {
+    name: "op_callstack"
+    type: STRINGS
+  }
+  attrs {
+    name: "op_device"
+    type: STRING
+  }
+}