未验证 提交 bd68761a 编写于 作者: W Wangzheee 提交者: GitHub

[ pass_enhance ]quant_conv2d_dequant_fuse_pass (#33737)

上级 3ad6630f
...@@ -21,11 +21,209 @@ ...@@ -21,11 +21,209 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// Constructor: registers one OpCompat description per operator type that this
// pass declares constraints for. Each AddOpCompat(...) chain lists the
// operator's inputs, outputs and attributes; every AddInput/AddOutput/AddAttr
// entry is closed by End() and may carry constraints in between (IsTensor,
// IsOptional, IsType<T>, IsIntIn, IsStringIn, IsNumGT/GE/EQ).
// The match handlers of this pass call IsCompat(subgraph, g) and return early
// with a warning when that check fails.
QuantDequantFusePass::QuantDequantFusePass() {
// --- fake_quantize_range_abs_max ---
// Inputs X, InScale, Iter and outputs Out, OutScale, OutScales are all
// declared as tensors and none is marked optional.
AddOpCompat(OpCompat("fake_quantize_range_abs_max"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("InScale")
.IsTensor()
.End()
.AddInput("Iter")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("OutScale")
.IsTensor()
.End()
.AddOutput("OutScales")
.IsTensor()
.End()
// window_size: int constrained by IsNumGT(0).
.AddAttr("window_size")
.IsType<int>()
.IsNumGT(0)
.End()
// bit_length: restricted to the values 8 and 16.
.AddAttr("bit_length")
.IsIntIn({8, 16})
.End();
// --- fake_quantize_moving_average_abs_max ---
// InAccum/InState and OutState/OutAccum are marked optional; X, InScale, Out
// and OutScale are not.
AddOpCompat(OpCompat("fake_quantize_moving_average_abs_max"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("InScale")
.IsTensor()
.End()
.AddInput("InAccum")
.IsTensor()
.IsOptional()
.End()
.AddInput("InState")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("OutScale")
.IsTensor()
.End()
.AddOutput("OutState")
.IsTensor()
.IsOptional()
.End()
.AddOutput("OutAccum")
.IsTensor()
.IsOptional()
.End()
// moving_rate: float constrained by IsNumGT(0.0f).
.AddAttr("moving_rate")
.IsType<float>()
.IsNumGT(0.0f)
.End()
// bit_length: restricted to the values 8 and 16.
.AddAttr("bit_length")
.IsIntIn({8, 16})
.End();
// --- fake_dequantize_max_abs ---
AddOpCompat(OpCompat("fake_dequantize_max_abs"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
// max_range: float constrained by IsNumGT(0.0f).
.AddAttr("max_range")
.IsType<float>()
.IsNumGT(0.0f)
.End();
// --- fake_channel_wise_dequantize_max_abs ---
// "Scales" deliberately carries no IsTensor() constraint (see inline note).
AddOpCompat(OpCompat("fake_channel_wise_dequantize_max_abs"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scales") // "Scales" is a vector with at most two tensors
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("quant_bits")
.IsType<std::vector<int>>()
.End()
// quant_axis: optional, restricted to 0 or 1 when present.
.AddAttr("quant_axis")
.IsIntIn({0, 1})
.IsOptional()
.End();
// --- conv2d ---
// Bias and ResidualData are optional inputs.
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
// NOTE(review): IsOptional() precedes IsStringIn() here but follows it in the
// conv2d_transpose entry below; presumably equivalent -- confirm.
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
// --- mul ---
// x_num_col_dims is constrained by IsNumGE(1), y_num_col_dims by IsNumEQ(1).
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
// --- fc ---
// NOTE(review): unlike conv2d/conv2d_transpose, "Bias" is not marked optional
// here -- confirm this is intended.
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("in_num_col_dims")
.IsNumGE(1)
.End()
// activation_type: either "relu" or the empty string.
.AddAttr("activation_type")
.IsStringIn({"relu", ""})
.End();
// --- conv2d_transpose ---
// Same attribute set as conv2d above; no ResidualData input is declared.
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
// Delete quant op before quantized ops, and set input scale in the attr of
// quantized ops
void DeleteQuant(ir::Graph* graph, Scope* scope, void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
const std::string& quant_type) { const std::string& quant_type) const {
const std::string pattern_name = "delete_quant_fuse"; const std::string pattern_name = "delete_quant_fuse";
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* input_act_node = gpd.mutable_pattern() auto* input_act_node = gpd.mutable_pattern()
...@@ -41,6 +239,10 @@ void DeleteQuant(ir::Graph* graph, Scope* scope, ...@@ -41,6 +239,10 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
// ops linked from it // ops linked from it
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
subgraph.count(input_act_node), true, subgraph.count(input_act_node), true,
platform::errors::NotFound( platform::errors::NotFound(
...@@ -103,9 +305,9 @@ void DeleteQuant(ir::Graph* graph, Scope* scope, ...@@ -103,9 +305,9 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
// Delete dequant op after quantized ops, and convert weight from fp32 range to
// int8 range
void FuseDequant(ir::Graph* graph, Scope* scope, void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
const std::string& quantized_op_type, const std::string& quantized_op_type,
const std::string& dequant_type) { const std::string& dequant_type) const {
std::string weight_name = ""; std::string weight_name = "";
std::string input_name = ""; std::string input_name = "";
if (quantized_op_type == "conv2d" || if (quantized_op_type == "conv2d" ||
...@@ -142,6 +344,10 @@ void FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -142,6 +344,10 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
// Create new op desc // Create new op desc
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
subgraph.count(quantized_op_input), true, subgraph.count(quantized_op_input), true,
platform::errors::NotFound("Quantized op input node(%s) did not find " platform::errors::NotFound("Quantized op input node(%s) did not find "
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include <memory> #include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -25,14 +24,20 @@ namespace ir { ...@@ -25,14 +24,20 @@ namespace ir {
///
/// Fuse quant + conv2d/depthwise_conv2d/mul/fc + dequant
///
class Graph;
class QuantDequantFusePass : public FusePassBase { class QuantDequantFusePass : public FusePassBase {
public: public:
QuantDequantFusePass();
virtual ~QuantDequantFusePass() {} virtual ~QuantDequantFusePass() {}
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
private:
void DeleteQuant(ir::Graph* graph, Scope* scope,
const std::string& quant_type) const;
void FuseDequant(ir::Graph* graph, Scope* scope,
const std::string& quantized_op_type,
const std::string& dequant_type) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -41,6 +41,22 @@ def { ...@@ -41,6 +41,22 @@ def {
} }
} }
extra { extra {
attrs {
name: "Input_scale"
type: FLOAT
}
attrs {
name: "quantization_type"
type: STRING
}
attrs {
name: "bit_length"
type: INT
}
attrs {
name: "out_threshold"
type: FLOAT
}
attrs { attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@" name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN type: BOOLEAN
......
# Op description for "fake_channel_wise_dequantize_max_abs".
# `def` lists the inputs (X, Scales), the output (Out) and the attributes
# quant_bits (INTS) and quant_axis (INT).
# `extra` lists further attributes declared for this op.
type: "fake_channel_wise_dequantize_max_abs"
def {
inputs {
name: "X"
}
inputs {
name: "Scales"
}
outputs {
name: "Out"
}
attrs {
name: "quant_bits"
type: INTS
}
attrs {
name: "quant_axis"
type: INT
}
}
# Attributes outside the `def` section.
extra {
attrs {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
...@@ -1183,7 +1183,8 @@ class QuantizationFreezePass(object): ...@@ -1183,7 +1183,8 @@ class QuantizationFreezePass(object):
if op_node_desc.has_attr("quantization_type") and \ if op_node_desc.has_attr("quantization_type") and \
op_node_desc.attr("quantization_type") == "qat_with_weight": op_node_desc.attr("quantization_type") == "qat_with_weight":
if self._weight_quantize_type == 'channel_wise_abs_max': if self._weight_quantize_type == 'channel_wise_abs_max':
self._insert_post_channel_dequant_op(graph, op_node) self._insert_post_channel_dequant_op(graph, op_node,
quant_axis)
else: else:
self._insert_post_dequant_op(graph, op_node) self._insert_post_dequant_op(graph, op_node)
...@@ -1210,7 +1211,7 @@ class QuantizationFreezePass(object): ...@@ -1210,7 +1211,7 @@ class QuantizationFreezePass(object):
v.node] v.node]
graph.safe_remove_nodes(op_node) graph.safe_remove_nodes(op_node)
def _insert_post_channel_dequant_op(self, graph, op_node): def _insert_post_channel_dequant_op(self, graph, op_node, quant_axis):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()] persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
for var_node in op_node.inputs: for var_node in op_node.inputs:
name = var_node.name() name = var_node.name()
...@@ -1258,6 +1259,7 @@ class QuantizationFreezePass(object): ...@@ -1258,6 +1259,7 @@ class QuantizationFreezePass(object):
op_type='fake_channel_wise_dequantize_max_abs', op_type='fake_channel_wise_dequantize_max_abs',
attrs={ attrs={
'quant_bits': [self._weight_bits, self._activation_bits], 'quant_bits': [self._weight_bits, self._activation_bits],
'quant_axis': quant_axis,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
}, },
inputs={ inputs={
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册