未验证 提交 5cd3b4b1 编写于 作者: G Guanghua Yu 提交者: GitHub

fix QuantizeLinear kernel and pass in QAT (#44784)

上级 e1515e40
...@@ -139,7 +139,8 @@ struct FindMovingAverageAbsMaxFunctor { ...@@ -139,7 +139,8 @@ struct FindMovingAverageAbsMaxFunctor {
void operator()(const DeviceContext &ctx, void operator()(const DeviceContext &ctx,
const framework::Tensor &in_accum, const framework::Tensor &in_accum,
const framework::Tensor &in_state, const framework::Tensor &in_state,
const framework::Tensor &cur_scale, const T *cur_scale,
const float rate,
framework::Tensor *out_state, framework::Tensor *out_state,
framework::Tensor *out_accum, framework::Tensor *out_accum,
framework::Tensor *out_scale); framework::Tensor *out_scale);
......
...@@ -93,6 +93,12 @@ class QuantizeLinearOp : public framework::OperatorWithKernel { ...@@ -93,6 +93,12 @@ class QuantizeLinearOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]}); ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]});
} }
} }
if (ctx->HasOutput("OutState")) {
ctx->SetOutputDim("OutState", {1});
}
if (ctx->HasOutput("OutAccum")) {
ctx->SetOutputDim("OutAccum", {1});
}
ctx->ShareLoD("X", /*->*/ "Y"); ctx->ShareLoD("X", /*->*/ "Y");
} }
...@@ -113,7 +119,25 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -113,7 +119,25 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Y", AddOutput("Y",
"(Tensor) Output of quantized low level tensor, " "(Tensor) Output of quantized low level tensor, "
"but also saved as float data type."); "but also saved as float data type.");
AddOutput("OutScale", "(Tensor) Current scale").AsDispensable().AsExtra(); AddInput("InAccum", "Last accum.")
.AsDispensable()
.AsExtra(); // only qat use
AddInput("InState", "Last state.")
.AsDispensable()
.AsExtra(); // only qat use
AddOutput("OutState", "(Tensor) state buffer.")
.AsDispensable()
.AsExtra(); // only qat use
AddOutput("OutAccum", "(Tensor) accum buffer.")
.AsDispensable()
.AsExtra(); // only qat use
AddOutput("OutScale", "(Tensor) Current scale")
.AsDispensable()
.AsExtra(); // only qat use
AddAttr<float>("moving_rate",
"(float, default 0.9) moving rate.") // only qat use
.SetDefault(0.9)
.AsExtra();
AddAttr<int>("quant_axis", AddAttr<int>("quant_axis",
"(int, default 0) The axis for quantization. " "(int, default 0) The axis for quantization. "
"For conv2d, depthwise_conv2d, conv2d_transpose " "For conv2d, depthwise_conv2d, conv2d_transpose "
...@@ -154,8 +178,7 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -154,8 +178,7 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
"nearest ties to even and 1 is rounding to nearest " "nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d", "ties away from zero.but the received is %d",
round_type)); round_type));
}) });
.AsExtra();
AddAttr<bool>("is_test", AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.") "for training. Some layers may run faster when this is true.")
......
...@@ -57,10 +57,31 @@ class QuantizeLinearKernel : public framework::OpKernel<T> { ...@@ -57,10 +57,31 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
if (quant_axis < 0) { if (quant_axis < 0) {
if (!is_test) { if (!is_test) {
auto* out_scale = context.Output<framework::Tensor>("OutScale"); // training
T* out_s = out_scale->mutable_data<T>(context.GetPlace()); auto* in_accum = context.Input<framework::Tensor>("InAccum");
auto* in_state = context.Input<framework::Tensor>("InState");
auto cur_scale = memory::Alloc(dev_ctx, sizeof(T));
T* cur_scale_data = static_cast<T*>(cur_scale->ptr());
FindAbsMaxFunctor<DeviceContext, T>()( FindAbsMaxFunctor<DeviceContext, T>()(
dev_ctx, in->data<T>(), in->numel(), out_s); dev_ctx, in->data<T>(), in->numel(), cur_scale_data);
auto* out_state = context.Output<framework::Tensor>("OutState");
auto* out_accum = context.Output<framework::Tensor>("OutAccum");
auto* out_scale = context.Output<framework::Tensor>("OutScale");
out_state->mutable_data<T>(context.GetPlace());
out_accum->mutable_data<T>(context.GetPlace());
out_scale->mutable_data<T>(context.GetPlace());
float moving_rate = context.Attr<float>("moving_rate");
FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
*in_accum,
*in_state,
cur_scale_data,
moving_rate,
out_state,
out_accum,
out_scale);
ClipAndFakeQuantFunctor<DeviceContext, T>()( ClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, round_type, out); dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
} else { } else {
......
...@@ -418,8 +418,7 @@ class PostTrainingQuantization(object): ...@@ -418,8 +418,7 @@ class PostTrainingQuantization(object):
self._update_program() self._update_program()
# save out_threshold for quantized ops. # save out_threshold for quantized ops.
if not self._onnx_format: self._save_output_threshold()
self._save_output_threshold()
if any(op_type in self._quantizable_op_type if any(op_type in self._quantizable_op_type
for op_type in self._dynamic_quantize_op_type): for op_type in self._dynamic_quantize_op_type):
...@@ -996,16 +995,23 @@ class PostTrainingQuantization(object): ...@@ -996,16 +995,23 @@ class PostTrainingQuantization(object):
''' '''
Save output threshold to the quantized op. Save output threshold to the quantized op.
''' '''
self._calibration_scales = {}
def save_info(op_node, out_var_name, threshold_map, out_info_name, def save_info(op_node, out_var_name, threshold_map, out_info_name,
quantized_type): quantized_type):
assert out_var_name in threshold_map, \ assert out_var_name in threshold_map, \
"The output ({}) of {} node does not have threshold.".format( "The output ({}) of {} node does not have threshold.".format(
out_var_name, op_node.type) out_var_name, op_node.type)
op_node._set_attr(out_info_name, threshold_map[var_name]) if self._onnx_format:
op_node._set_attr("with_quant_attr", True) # For easy extension, every var_node set a dict to save parameters of quant.
if op_node.type in self._quantizable_op_type: self._calibration_scales[var_name] = {}
op._set_attr("quantization_type", quantized_type) self._calibration_scales[var_name]['scale'] = threshold_map[
var_name]
else:
op_node._set_attr(out_info_name, threshold_map[var_name])
op_node._set_attr("with_quant_attr", True)
if op_node.type in self._quantizable_op_type:
op._set_attr("quantization_type", quantized_type)
def analysis_and_save_info(op_node, out_var_name): def analysis_and_save_info(op_node, out_var_name):
argname_index = utils._get_output_name_index(op_node, out_var_name) argname_index = utils._get_output_name_index(op_node, out_var_name)
......
...@@ -1792,6 +1792,7 @@ class InsertQuantizeLinear(object): ...@@ -1792,6 +1792,7 @@ class InsertQuantizeLinear(object):
equal to 0, it will quantization with per channel, else quantization with per layer. equal to 0, it will quantization with per channel, else quantization with per layer.
Default is -1. Default is -1.
channel_wise(bool, optional): Whether quantization with per channel or not. Default is False. channel_wise(bool, optional): Whether quantization with per channel or not. Default is False.
moving_rate(float): the rate for 'moving average' method.
is_test(bool, optional): Whether quantization with training or not. Default is True. is_test(bool, optional): Whether quantization with training or not. Default is True.
""" """
...@@ -1801,6 +1802,7 @@ class InsertQuantizeLinear(object): ...@@ -1801,6 +1802,7 @@ class InsertQuantizeLinear(object):
quant_bits=8, quant_bits=8,
quant_axis=-1, quant_axis=-1,
channel_wise=False, channel_wise=False,
moving_rate=0.9,
is_test=True): is_test=True):
self._place = place self._place = place
self._scope = scope self._scope = scope
...@@ -1808,15 +1810,16 @@ class InsertQuantizeLinear(object): ...@@ -1808,15 +1810,16 @@ class InsertQuantizeLinear(object):
self.quant_axis = quant_axis self.quant_axis = quant_axis
self.channel_wise = channel_wise self.channel_wise = channel_wise
self._is_test = is_test self._is_test = is_test
self._moving_rate = moving_rate
def insert_quant_op(self, graph, var_node): def insert_quant_op(self, graph, var_node, var_name=None):
assert var_node.is_var(), '{} is not a var'.format(var_node.name()) assert var_node.is_var(), '{} is not a var'.format(var_node.name())
var_name = var_node.name() if not var_name else var_name
quant_var_node = graph.create_var_node(name=self._quantized_var_name( quant_var_node = graph.create_var_node(
var_node.name()), name=self._quantized_var_name(var_name),
var_type=var_node.type(), var_type=var_node.type(),
shape=var_node.shape(), shape=var_node.shape(),
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype( data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
if self.channel_wise: if self.channel_wise:
...@@ -1828,7 +1831,7 @@ class InsertQuantizeLinear(object): ...@@ -1828,7 +1831,7 @@ class InsertQuantizeLinear(object):
scale_var_type = var_node.type() scale_var_type = var_node.type()
init_scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type) init_scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)
scale_var_node = graph.create_persistable_node( scale_var_node = graph.create_persistable_node(
name=self._quantized_scale_name(var_node.name()), name=self._quantized_scale_name(var_name),
var_type=scale_var_type, var_type=scale_var_type,
shape=[scale_var_shape], shape=[scale_var_shape],
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
...@@ -1851,13 +1854,39 @@ class InsertQuantizeLinear(object): ...@@ -1851,13 +1854,39 @@ class InsertQuantizeLinear(object):
inputs["ZeroPoint"] = zero_point_node inputs["ZeroPoint"] = zero_point_node
attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits} attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
outputs = {"Y": quant_var_node} outputs = {"Y": quant_var_node}
if not self._is_test: if not self._is_test:
attrs["is_test"] = self._is_test
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
scale_out_node = graph.create_var_node_from_desc( scale_out_node = graph.create_var_node_from_desc(
scale_var_node.var()) scale_var_node.var())
state_in_node = graph.create_persistable_node(
name=unique_name.generate('state'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(),
shape=[1])
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node(state_in_node, np.ones([1], dtype=data_type),
self._scope, self._place)
accum_in_node = graph.create_persistable_node(
name=unique_name.generate('accum'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(),
shape=[1])
_init_var_node(accum_in_node, np.ones([1], dtype=data_type),
self._scope, self._place)
state_out_node = graph.create_var_node_from_desc(
state_in_node.var())
accum_out_node = graph.create_var_node_from_desc(
accum_in_node.var())
outputs["OutScale"] = scale_out_node outputs["OutScale"] = scale_out_node
inputs['InState'] = state_in_node
inputs['InAccum'] = accum_in_node
outputs['OutState'] = state_out_node
outputs['OutAccum'] = accum_out_node
attrs["is_test"] = self._is_test
attrs['moving_rate'] = self._moving_rate
quant_op_node = graph.create_op_node(op_type="quantize_linear", quant_op_node = graph.create_op_node(op_type="quantize_linear",
attrs=attrs, attrs=attrs,
...@@ -1870,6 +1899,10 @@ class InsertQuantizeLinear(object): ...@@ -1870,6 +1899,10 @@ class InsertQuantizeLinear(object):
graph.link_to(zero_point_node, quant_op_node) graph.link_to(zero_point_node, quant_op_node)
graph.link_to(quant_op_node, quant_var_node) graph.link_to(quant_op_node, quant_var_node)
if not self._is_test: if not self._is_test:
graph.link_to(state_in_node, quant_op_node)
graph.link_to(accum_in_node, quant_op_node)
graph.link_to(quant_op_node, state_out_node)
graph.link_to(quant_op_node, accum_out_node)
graph.link_to(quant_op_node, scale_out_node) graph.link_to(quant_op_node, scale_out_node)
return quant_var_node, scale_var_node return quant_var_node, scale_var_node
...@@ -1898,8 +1931,7 @@ class InsertQuantizeLinear(object): ...@@ -1898,8 +1931,7 @@ class InsertQuantizeLinear(object):
inputs["ZeroPoint"] = zero_point_node inputs["ZeroPoint"] = zero_point_node
attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits} attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
if not self._is_test: attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
quant_op_node = graph.create_op_node(op_type="dequantize_linear", quant_op_node = graph.create_op_node(op_type="dequantize_linear",
attrs=attrs, attrs=attrs,
...@@ -1938,10 +1970,10 @@ class InsertQuantizeLinear(object): ...@@ -1938,10 +1970,10 @@ class InsertQuantizeLinear(object):
return "%s@zero_point" % (var_name) return "%s@zero_point" % (var_name)
class QuantizationTransformPassV2(object): class QuantizationTransformPassV2(QuantizationTransformPass):
""" """
Quantize the ops that have weights. Add quant and dequant ops for Quantize the ops that have weights. Add quant and dequant ops for
the quantized ops's inputs. the quantized ops's inputs. It is used in the new format of quantization.
""" """
def __init__(self, def __init__(self,
...@@ -2137,13 +2169,13 @@ class QuantizationTransformPassV2(object): ...@@ -2137,13 +2169,13 @@ class QuantizationTransformPassV2(object):
if is_weight and self._weight_quantize_func is not None: if is_weight and self._weight_quantize_func is not None:
target_out_node = self._insert_func( target_out_node = self._insert_func(
graph, self._weight_quantize_func, var_node, op) graph, self._weight_quantize_func, var_node, op)
processed_vars.append(name) self.processed_vars.append(name)
continue continue
elif not is_weight and self._act_quantize_func is not None: elif not is_weight and self._act_quantize_func is not None:
target_out_node = self._insert_func(graph, target_out_node = self._insert_func(graph,
self._act_quantize_func, self._act_quantize_func,
var_node, op) var_node, op)
processed_vars.append(name) self.processed_vars.append(name)
continue continue
quant_bits = self._weight_bits if var_node.name() in self.persistable_vars \ quant_bits = self._weight_bits if var_node.name() in self.persistable_vars \
...@@ -2162,9 +2194,10 @@ class QuantizationTransformPassV2(object): ...@@ -2162,9 +2194,10 @@ class QuantizationTransformPassV2(object):
quant_bits=quant_bits, quant_bits=quant_bits,
quant_axis=quant_axis, quant_axis=quant_axis,
channel_wise=channel_wise, channel_wise=channel_wise,
moving_rate=self._moving_rate,
is_test=self._is_test) is_test=self._is_test)
quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op( quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
graph, var_node) graph, var_node, var_name=name)
dequant_var_node = insert_quant_pass.insert_dequant_op( dequant_var_node = insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node) graph, quant_var_node, scale_var_node)
...@@ -2189,24 +2222,6 @@ class QuantizationTransformPassV2(object): ...@@ -2189,24 +2222,6 @@ class QuantizationTransformPassV2(object):
has_weight = True has_weight = True
return has_weight return has_weight
def _is_skip_quant(self, graph, op_node):
"""
Analyse whether the op node skips quantization.
"""
is_skip = False
if op_node.op().has_attr("skip_quant") and \
op_node.op().attr("skip_quant"):
is_skip = True
# if the inputs of mul and matmul are not all persistable, use
# AddQuantDequantPassV2 to quantize them.
if op_node.name() in ["mul", "matmul", "matmul_v2"] and \
_is_input_all_not_persistable(graph, op_node):
is_skip = True
if op_node.op().has_attr("quantization_type") and \
op_node.op().attr("quantization_type") == "qat_without_weight":
is_skip = True
return is_skip
def apply(self, graph): def apply(self, graph):
""" """
Quantize the graph for training process. According to weight and Quantize the graph for training process. According to weight and
...@@ -2257,7 +2272,7 @@ class QuantizationTransformPassV2(object): ...@@ -2257,7 +2272,7 @@ class QuantizationTransformPassV2(object):
class AddQuantDequantPassV2(object): class AddQuantDequantPassV2(object):
""" """
Quantize the ops that do not have weights, and add quant_linear and dequant_linear Quantize the ops that do not have weights, and add quant_linear and dequant_linear
op for the quantized ops's inputs. op for the quantized ops's inputs. It is used in the new format of quantization.
""" """
# To be compatible with PaddleSlim, not remove _activation_type for now # To be compatible with PaddleSlim, not remove _activation_type for now
...@@ -2384,6 +2399,7 @@ class AddQuantDequantPassV2(object): ...@@ -2384,6 +2399,7 @@ class AddQuantDequantPassV2(object):
quant_bits=self._quant_bits, quant_bits=self._quant_bits,
quant_axis=-1, quant_axis=-1,
channel_wise=False, channel_wise=False,
moving_rate=self._moving_rate,
is_test=self._is_test) is_test=self._is_test)
quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op( quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
graph, in_node) graph, in_node)
......
...@@ -550,18 +550,41 @@ class TestquantizeOpTrain(TestquantizeOp): ...@@ -550,18 +550,41 @@ class TestquantizeOpTrain(TestquantizeOp):
def setUp(self): def setUp(self):
self.set_args() self.set_args()
self.op_type = "quantize_linear" self.op_type = "quantize_linear"
x = np.random.randn(31, 65).astype(self.data_type)
yq, scale = quantize_max_abs(x, self.max_range)
scale = np.array(scale).astype(self.data_type)
zero_point = np.zeros(scale.shape, dtype="int32")
self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
self.attrs = { self.attrs = {
'bit_length': self.bit_length, 'bit_length': self.bit_length,
'quant_axis': self.quant_axis, 'quant_axis': self.quant_axis,
'moving_rate': 0.9,
'is_test': self.is_test 'is_test': self.is_test
} }
self.outputs = {'Y': yq, 'OutScale': scale}
x = np.random.randn(31, 65).astype(self.data_type)
scale = np.array([0.001]).astype(self.data_type)
zero_point = np.zeros(scale.shape, dtype="int32")
in_accum = np.ones(1).astype(self.data_type)
in_state = np.ones(1).astype(self.data_type)
out_accum = np.zeros(1).astype(self.data_type)
out_state = np.zeros(1).astype(self.data_type)
out_accum[0] = self.attrs['moving_rate'] * in_accum[0] + np.max(
np.abs(x))
out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0
out_scale = out_accum / out_state
round_out = np.round(x / out_scale * self.max_range)
quant_data = np.clip(round_out, -self.max_range - 1, self.max_range)
self.inputs = {
'X': x,
'Scale': scale,
'ZeroPoint': zero_point,
'InAccum': in_accum,
'InState': in_state,
}
self.outputs = {
'Y': quant_data,
'OutScale': out_scale,
'OutAccum': out_accum,
'OutState': out_state,
}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册