未验证 提交 5cd3b4b1 编写于 作者: G Guanghua Yu 提交者: GitHub

fix QuantizeLinear kernel and pass in QAT (#44784)

上级 e1515e40
......@@ -139,7 +139,8 @@ struct FindMovingAverageAbsMaxFunctor {
void operator()(const DeviceContext &ctx,
const framework::Tensor &in_accum,
const framework::Tensor &in_state,
const framework::Tensor &cur_scale,
const T *cur_scale,
const float rate,
framework::Tensor *out_state,
framework::Tensor *out_accum,
framework::Tensor *out_scale);
......
......@@ -93,6 +93,12 @@ class QuantizeLinearOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]});
}
}
if (ctx->HasOutput("OutState")) {
ctx->SetOutputDim("OutState", {1});
}
if (ctx->HasOutput("OutAccum")) {
ctx->SetOutputDim("OutAccum", {1});
}
ctx->ShareLoD("X", /*->*/ "Y");
}
......@@ -113,7 +119,25 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Y",
"(Tensor) Output of quantized low level tensor, "
"but also saved as float data type.");
AddOutput("OutScale", "(Tensor) Current scale").AsDispensable().AsExtra();
AddInput("InAccum", "Last accum.")
.AsDispensable()
.AsExtra(); // only qat use
AddInput("InState", "Last state.")
.AsDispensable()
.AsExtra(); // only qat use
AddOutput("OutState", "(Tensor) state buffer.")
.AsDispensable()
.AsExtra(); // only qat use
AddOutput("OutAccum", "(Tensor) accum buffer.")
.AsDispensable()
.AsExtra(); // only qat use
AddOutput("OutScale", "(Tensor) Current scale")
.AsDispensable()
.AsExtra(); // only qat use
AddAttr<float>("moving_rate",
"(float, default 0.9) moving rate.") // only qat use
.SetDefault(0.9)
.AsExtra();
AddAttr<int>("quant_axis",
"(int, default 0) The axis for quantization. "
"For conv2d, depthwise_conv2d, conv2d_transpose "
......@@ -154,8 +178,7 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d",
round_type));
})
.AsExtra();
});
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
......
......@@ -57,10 +57,31 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
if (quant_axis < 0) {
if (!is_test) {
auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_s = out_scale->mutable_data<T>(context.GetPlace());
// training
auto* in_accum = context.Input<framework::Tensor>("InAccum");
auto* in_state = context.Input<framework::Tensor>("InState");
auto cur_scale = memory::Alloc(dev_ctx, sizeof(T));
T* cur_scale_data = static_cast<T*>(cur_scale->ptr());
FindAbsMaxFunctor<DeviceContext, T>()(
dev_ctx, in->data<T>(), in->numel(), out_s);
dev_ctx, in->data<T>(), in->numel(), cur_scale_data);
auto* out_state = context.Output<framework::Tensor>("OutState");
auto* out_accum = context.Output<framework::Tensor>("OutAccum");
auto* out_scale = context.Output<framework::Tensor>("OutScale");
out_state->mutable_data<T>(context.GetPlace());
out_accum->mutable_data<T>(context.GetPlace());
out_scale->mutable_data<T>(context.GetPlace());
float moving_rate = context.Attr<float>("moving_rate");
FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
*in_accum,
*in_state,
cur_scale_data,
moving_rate,
out_state,
out_accum,
out_scale);
ClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
} else {
......
......@@ -418,8 +418,7 @@ class PostTrainingQuantization(object):
self._update_program()
# save out_threshold for quantized ops.
if not self._onnx_format:
self._save_output_threshold()
self._save_output_threshold()
if any(op_type in self._quantizable_op_type
for op_type in self._dynamic_quantize_op_type):
......@@ -996,16 +995,23 @@ class PostTrainingQuantization(object):
'''
Save output threshold to the quantized op.
'''
self._calibration_scales = {}
def save_info(op_node, out_var_name, threshold_map, out_info_name,
quantized_type):
assert out_var_name in threshold_map, \
"The output ({}) of {} node does not have threshold.".format(
out_var_name, op_node.type)
op_node._set_attr(out_info_name, threshold_map[var_name])
op_node._set_attr("with_quant_attr", True)
if op_node.type in self._quantizable_op_type:
op._set_attr("quantization_type", quantized_type)
if self._onnx_format:
# For easy extension, every var_node set a dict to save parameters of quant.
self._calibration_scales[var_name] = {}
self._calibration_scales[var_name]['scale'] = threshold_map[
var_name]
else:
op_node._set_attr(out_info_name, threshold_map[var_name])
op_node._set_attr("with_quant_attr", True)
if op_node.type in self._quantizable_op_type:
op._set_attr("quantization_type", quantized_type)
def analysis_and_save_info(op_node, out_var_name):
argname_index = utils._get_output_name_index(op_node, out_var_name)
......
......@@ -1792,6 +1792,7 @@ class InsertQuantizeLinear(object):
equal to 0, it will quantization with per channel, else quantization with per layer.
Default is -1.
channel_wise(bool, optional): Whether quantization with per channel or not. Default is False.
moving_rate(float): the rate for 'moving average' method.
is_test(bool, optional): Whether quantization with training or not. Default is True.
"""
......@@ -1801,6 +1802,7 @@ class InsertQuantizeLinear(object):
quant_bits=8,
quant_axis=-1,
channel_wise=False,
moving_rate=0.9,
is_test=True):
self._place = place
self._scope = scope
......@@ -1808,15 +1810,16 @@ class InsertQuantizeLinear(object):
self.quant_axis = quant_axis
self.channel_wise = channel_wise
self._is_test = is_test
self._moving_rate = moving_rate
def insert_quant_op(self, graph, var_node):
def insert_quant_op(self, graph, var_node, var_name=None):
assert var_node.is_var(), '{} is not a var'.format(var_node.name())
quant_var_node = graph.create_var_node(name=self._quantized_var_name(
var_node.name()),
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
var_name = var_node.name() if not var_name else var_name
quant_var_node = graph.create_var_node(
name=self._quantized_var_name(var_name),
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
if self.channel_wise:
......@@ -1828,7 +1831,7 @@ class InsertQuantizeLinear(object):
scale_var_type = var_node.type()
init_scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)
scale_var_node = graph.create_persistable_node(
name=self._quantized_scale_name(var_node.name()),
name=self._quantized_scale_name(var_name),
var_type=scale_var_type,
shape=[scale_var_shape],
var_dtype=var_node.dtype())
......@@ -1851,13 +1854,39 @@ class InsertQuantizeLinear(object):
inputs["ZeroPoint"] = zero_point_node
attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
outputs = {"Y": quant_var_node}
if not self._is_test:
attrs["is_test"] = self._is_test
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
scale_out_node = graph.create_var_node_from_desc(
scale_var_node.var())
state_in_node = graph.create_persistable_node(
name=unique_name.generate('state'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(),
shape=[1])
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node(state_in_node, np.ones([1], dtype=data_type),
self._scope, self._place)
accum_in_node = graph.create_persistable_node(
name=unique_name.generate('accum'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(),
shape=[1])
_init_var_node(accum_in_node, np.ones([1], dtype=data_type),
self._scope, self._place)
state_out_node = graph.create_var_node_from_desc(
state_in_node.var())
accum_out_node = graph.create_var_node_from_desc(
accum_in_node.var())
outputs["OutScale"] = scale_out_node
inputs['InState'] = state_in_node
inputs['InAccum'] = accum_in_node
outputs['OutState'] = state_out_node
outputs['OutAccum'] = accum_out_node
attrs["is_test"] = self._is_test
attrs['moving_rate'] = self._moving_rate
quant_op_node = graph.create_op_node(op_type="quantize_linear",
attrs=attrs,
......@@ -1870,6 +1899,10 @@ class InsertQuantizeLinear(object):
graph.link_to(zero_point_node, quant_op_node)
graph.link_to(quant_op_node, quant_var_node)
if not self._is_test:
graph.link_to(state_in_node, quant_op_node)
graph.link_to(accum_in_node, quant_op_node)
graph.link_to(quant_op_node, state_out_node)
graph.link_to(quant_op_node, accum_out_node)
graph.link_to(quant_op_node, scale_out_node)
return quant_var_node, scale_var_node
......@@ -1898,8 +1931,7 @@ class InsertQuantizeLinear(object):
inputs["ZeroPoint"] = zero_point_node
attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
if not self._is_test:
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
quant_op_node = graph.create_op_node(op_type="dequantize_linear",
attrs=attrs,
......@@ -1938,10 +1970,10 @@ class InsertQuantizeLinear(object):
return "%s@zero_point" % (var_name)
class QuantizationTransformPassV2(object):
class QuantizationTransformPassV2(QuantizationTransformPass):
"""
Quantize the ops that have weights. Add quant and dequant ops for
the quantized ops's inputs.
the quantized ops's inputs. It is used in the new format of quantization.
"""
def __init__(self,
......@@ -2137,13 +2169,13 @@ class QuantizationTransformPassV2(object):
if is_weight and self._weight_quantize_func is not None:
target_out_node = self._insert_func(
graph, self._weight_quantize_func, var_node, op)
processed_vars.append(name)
self.processed_vars.append(name)
continue
elif not is_weight and self._act_quantize_func is not None:
target_out_node = self._insert_func(graph,
self._act_quantize_func,
var_node, op)
processed_vars.append(name)
self.processed_vars.append(name)
continue
quant_bits = self._weight_bits if var_node.name() in self.persistable_vars \
......@@ -2162,9 +2194,10 @@ class QuantizationTransformPassV2(object):
quant_bits=quant_bits,
quant_axis=quant_axis,
channel_wise=channel_wise,
moving_rate=self._moving_rate,
is_test=self._is_test)
quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
graph, var_node)
graph, var_node, var_name=name)
dequant_var_node = insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node)
......@@ -2189,24 +2222,6 @@ class QuantizationTransformPassV2(object):
has_weight = True
return has_weight
def _is_skip_quant(self, graph, op_node):
"""
Analyse whether the op node skips quantization.
"""
is_skip = False
if op_node.op().has_attr("skip_quant") and \
op_node.op().attr("skip_quant"):
is_skip = True
# if the inputs of mul and matmul are not all persistable, use
# AddQuantDequantPassV2 to quantize them.
if op_node.name() in ["mul", "matmul", "matmul_v2"] and \
_is_input_all_not_persistable(graph, op_node):
is_skip = True
if op_node.op().has_attr("quantization_type") and \
op_node.op().attr("quantization_type") == "qat_without_weight":
is_skip = True
return is_skip
def apply(self, graph):
"""
Quantize the graph for training process. According to weight and
......@@ -2257,7 +2272,7 @@ class QuantizationTransformPassV2(object):
class AddQuantDequantPassV2(object):
"""
Quantize the ops that do not have weights, and add quant_linear and dequant_linear
op for the quantized ops's inputs.
op for the quantized ops's inputs. It is used in the new format of quantization.
"""
# To be compatible with PaddleSlim, not remove _activation_type for now
......@@ -2384,6 +2399,7 @@ class AddQuantDequantPassV2(object):
quant_bits=self._quant_bits,
quant_axis=-1,
channel_wise=False,
moving_rate=self._moving_rate,
is_test=self._is_test)
quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
graph, in_node)
......
......@@ -550,18 +550,41 @@ class TestquantizeOpTrain(TestquantizeOp):
def setUp(self):
self.set_args()
self.op_type = "quantize_linear"
x = np.random.randn(31, 65).astype(self.data_type)
yq, scale = quantize_max_abs(x, self.max_range)
scale = np.array(scale).astype(self.data_type)
zero_point = np.zeros(scale.shape, dtype="int32")
self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
self.attrs = {
'bit_length': self.bit_length,
'quant_axis': self.quant_axis,
'moving_rate': 0.9,
'is_test': self.is_test
}
self.outputs = {'Y': yq, 'OutScale': scale}
x = np.random.randn(31, 65).astype(self.data_type)
scale = np.array([0.001]).astype(self.data_type)
zero_point = np.zeros(scale.shape, dtype="int32")
in_accum = np.ones(1).astype(self.data_type)
in_state = np.ones(1).astype(self.data_type)
out_accum = np.zeros(1).astype(self.data_type)
out_state = np.zeros(1).astype(self.data_type)
out_accum[0] = self.attrs['moving_rate'] * in_accum[0] + np.max(
np.abs(x))
out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0
out_scale = out_accum / out_state
round_out = np.round(x / out_scale * self.max_range)
quant_data = np.clip(round_out, -self.max_range - 1, self.max_range)
self.inputs = {
'X': x,
'Scale': scale,
'ZeroPoint': zero_point,
'InAccum': in_accum,
'InState': in_state,
}
self.outputs = {
'Y': quant_data,
'OutScale': out_scale,
'OutAccum': out_accum,
'OutState': out_state,
}
def test_check_output(self):
self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册