Unverified commit f9420e83, authored by 王明冬, committed by GitHub

add compat precondition for delete_quant_dequant_filter_op_pass, test=develop (#33705)

Parent d55f3b6f
@@ -32,6 +32,37 @@ namespace ir {
GET_IR_NODE(quant_dequant_op_outscale); \
GET_IR_NODE(any_op2);
DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
AddOpCompat(OpCompat("fake_quantize_dequantize_abs_max"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("OutScale")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsIntIn({8, 16})
.End();
AddOpCompat(OpCompat("fake_channel_wise_quantize_dequantize_abs_max"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("OutScale")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsIntIn({8, 16})
.End()
.AddAttr("quant_axis")
.IsIntIn({0, 1})
.End();
}
// Delete quant_dequant_op, then quantize and dequantize weight
void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "delete_quantdequant_filter_op_pattern";
@@ -50,6 +81,11 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
Graph* g) {
GET_NODES;
if (!IsCompat(*quant_dequant_op->Op())) {
LOG(WARNING) << "quant_dequant_op in delete_quant_dequant_filter_op_pass "
"compat check failed.";
return;
}
std::unordered_set<const Node*> nodes2rm = {};
int bit_length =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
......
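
The hunks above add a declarative OpCompat spec in the pass constructor and an IsCompat guard in ApplyImpl, so the pass only rewrites quant/dequant filter ops whose inputs/outputs are tensors and whose attributes fall in the allowed ranges (bit_length in {8, 16}, quant_axis in {0, 1}). The standalone C++ sketch below illustrates that pattern with simplified, hypothetical types (OpCompatSketch and AddIntIn are illustrative only, not Paddle's real OpCompat API): constraints are declared once up front, and a failed check triggers the same kind of early return as the IsCompat guard added above.

// Standalone illustration of a declarative attribute-compat check with an
// early-return guard. Hypothetical types; not Paddle's actual classes.
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>

struct AttrConstraint {
  std::set<int> allowed;  // e.g. bit_length in {8, 16}, quant_axis in {0, 1}
};

class OpCompatSketch {
 public:
  // Builder-style registration, similar in spirit to AddAttr(...).IsIntIn(...).
  OpCompatSketch& AddIntIn(const std::string& name, std::set<int> allowed) {
    constraints_[name] = AttrConstraint{std::move(allowed)};
    return *this;
  }
  // Returns false when an attribute is missing or outside its allowed set.
  bool IsCompat(const std::map<std::string, int>& attrs) const {
    for (const auto& kv : constraints_) {
      auto it = attrs.find(kv.first);
      if (it == attrs.end() || kv.second.allowed.count(it->second) == 0) {
        return false;
      }
    }
    return true;
  }

 private:
  std::map<std::string, AttrConstraint> constraints_;
};

int main() {
  OpCompatSketch compat;
  compat.AddIntIn("bit_length", {8, 16}).AddIntIn("quant_axis", {0, 1});

  std::map<std::string, int> ok_op = {{"bit_length", 8}, {"quant_axis", 0}};
  std::map<std::string, int> bad_op = {{"bit_length", 4}, {"quant_axis", 0}};

  // A pass would skip (early-return from) the rewrite when the check fails.
  std::cout << compat.IsCompat(ok_op) << "\n";   // 1
  std::cout << compat.IsCompat(bad_op) << "\n";  // 0
}
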
@@ -16,16 +16,14 @@
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
class DeleteQuantDequantFilterOpPass : public FusePassBase {
public:
DeleteQuantDequantFilterOpPass();
virtual ~DeleteQuantDequantFilterOpPass() {}
protected:
......
type: "fake_channel_wise_quantize_dequantize_abs_max"
def {
inputs {
name: "X"
}
outputs {
name: "Out"
}
outputs {
name: "OutScale"
}
attrs {
name: "quant_axis"
type: INT
}
attrs {
name: "bit_length"
type: INT
}
}
extra {
attrs {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
type: "fake_quantize_dequantize_abs_max"
def {
inputs {
name: "X"
}
outputs {
name: "Out"
}
outputs {
name: "OutScale"
}
attrs {
name: "bit_length"
type: INT
}
}
extra {
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
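
The two .pbtxt entries above give the op signatures the compat check is matched against: input X, outputs Out and OutScale, with bit_length (and, for the channel-wise op, quant_axis) as required INT attributes. As a rough illustration of what these ops compute, here is a short standalone sketch assuming the usual abs_max fake quant-dequant formula (scale = max|x|, bins = 2^(bit_length - 1) - 1); this is an assumption for illustration, not Paddle's actual kernel code.

// Sketch of abs_max fake quantize-dequantize, assuming the standard formula
// scale = max|x|, out = round(x / scale * bins) * scale / bins. Illustrative only.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Per-tensor variant, in the spirit of fake_quantize_dequantize_abs_max.
std::vector<float> FakeQuantDequant(const std::vector<float>& x, int bit_length) {
  const float bins = static_cast<float>((1 << (bit_length - 1)) - 1);  // 127 for 8 bits
  float scale = 0.f;
  for (float v : x) scale = std::max(scale, std::fabs(v));  // corresponds to OutScale
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    // Quantize onto the integer grid, then immediately dequantize.
    out[i] = std::round(x[i] / scale * bins) * scale / bins;
  }
  return out;
}

int main() {
  // With bit_length = 8 the weights are snapped onto 255 levels in [-scale, scale].
  std::vector<float> w = {0.50f, -1.00f, 0.26f};
  for (float v : FakeQuantDequant(w, 8)) std::printf("%f\n", v);
  // The channel-wise op applies the same formula per slice along quant_axis,
  // producing one OutScale entry per channel instead of a single scalar.
}
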