未验证 提交 491b87b4 编写于 作者: G Guanghua Yu 提交者: GitHub

fix quantization clip and round Attribute (#43764)

上级 2739bd73
......@@ -26,14 +26,17 @@ namespace operators {
template <typename T>
struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor* scale,
T max_range, const int quant_axis, framework::Tensor* out) {
void operator()(const platform::CPUDeviceContext &dev_ctx,
const framework::Tensor *in,
const framework::Tensor *scale,
T max_range,
const int quant_axis,
framework::Tensor *out) {
// Dequant op is before quantized op
// Dequantize the weight of quantized op
auto in_dims = in->dims();
const int64_t channel = in_dims[quant_axis];
const T* scale_factor = scale->data<T>();
const T *scale_factor = scale->data<T>();
if (quant_axis == 0) {
for (int64_t i = 0; i < channel; i++) {
T s = scale_factor[i];
......@@ -41,7 +44,7 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
auto& dev = *dev_ctx.eigen_device();
auto &dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * s / max_range;
}
} else if (quant_axis == 1) {
......@@ -51,12 +54,12 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
}
int64_t step_i = in->numel() / out_iter;
int64_t step_j = in->numel() / (out_iter * channel);
auto* in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
auto *in_data = in->data<T>();
auto *out_data = out->mutable_data<T>(dev_ctx.GetPlace());
for (int64_t i = 0; i < out_iter; i++) {
for (int64_t j = 0; j < channel; j++) {
auto* cur_in = in_data + i * step_i + j * step_j;
auto* cur_out = out_data + i * step_i + j * step_j;
auto *cur_in = in_data + i * step_i + j * step_j;
auto *cur_out = out_data + i * step_i + j * step_j;
T s = scale_factor[j];
for (int64_t k = 0; k < step_j; k++) {
*cur_out = (*cur_in) * s / max_range;
......@@ -75,11 +78,11 @@ template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, double>;
class QuantizeLinearOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "QuantizeLinear");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "QuantizeLinear");
OP_INOUT_CHECK(ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint",
"QuantizeLinear");
OP_INOUT_CHECK(
ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint", "QuantizeLinear");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "QuantizeLinear");
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
int quant_axis = ctx->Attrs().Get<int>("quant_axis");
......@@ -95,7 +98,7 @@ class QuantizeLinearOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
......@@ -116,9 +119,10 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
"For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis.")
.SetDefault(0)
.AddCustomChecker([](const int& quant_axis) {
.AddCustomChecker([](const int &quant_axis) {
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1 || quant_axis == -1, true,
quant_axis == 0 || quant_axis == 1 || quant_axis == -1,
true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
......@@ -126,8 +130,9 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
});
AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
.AddCustomChecker([](const int &bit_length) {
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
true,
platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
......@@ -140,13 +145,17 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3")
.SetDefault(0)
.AddCustomChecker([](const int& round_type) {
PADDLE_ENFORCE_EQ(round_type >= 0 && round_type <= 1, true,
platform::errors::InvalidArgument(
"'round_type' should be between 0 and 1, but "
"the received is %d",
round_type));
});
.AddCustomChecker([](const int &round_type) {
PADDLE_ENFORCE_EQ(
round_type == 0 || round_type == 1,
true,
platform::errors::InvalidArgument(
"'round_type' should be 0 or 1, 0 rounding to "
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d",
round_type));
})
.AsExtra();
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
......@@ -170,14 +179,18 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(
quantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker,
quantize_linear,
ops::QuantizeLinearOp,
ops::QuantizeLinearOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(quantize_linear, ops::QuantizeLinearKernel<CPU, float>);
REGISTER_OPERATOR(
dequantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker,
dequantize_linear,
ops::QuantizeLinearOp,
ops::QuantizeLinearOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......
......@@ -121,8 +121,7 @@ class PostTrainingQuantization(object):
algo="KL",
hist_percent=0.99999,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
weight_round_algo='round',
round_type='TiesToEven',
round_type='round',
learning_rate=0.001,
is_full_quantize=False,
bias_correction=False,
......@@ -181,14 +180,10 @@ class PostTrainingQuantization(object):
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul"].
weight_round_algo(str, optional): The method of converting the quantized weights
round_type(str, optional): The method of converting the quantized weights
value float->int. Currently supports ['round', 'adaround'] methods.
Default is `round`, which is rounding nearest to the integer.
'adaround' is refer to https://arxiv.org/abs/2004.10568.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
learning_rate(float, optional): The learning rate of adaround method.
is_full_quantized(bool, optional): If set is_full_quantized as True,
apply quantization to all supported quantizable op type. If set
......@@ -269,10 +264,8 @@ class PostTrainingQuantization(object):
self._support_algo_type = [
'KL', 'hist', 'avg', 'mse', 'emd', 'abs_max', 'min_max'
]
assert round_type in ['TiesToEven', 'TiesAwayFromZero']
assert round_type in ['adaround', 'round']
self._round_type = round_type
assert weight_round_algo in ['adaround', 'round']
self._weight_round_algo = weight_round_algo
self._learning_rate = learning_rate
self._dynamic_quantize_op_type = ['lstm']
self._support_quantize_op_type = \
......@@ -414,7 +407,7 @@ class PostTrainingQuantization(object):
if self._algo in ["KL", "hist"]:
self._calculate_kl_hist_threshold()
if self._weight_round_algo == 'adaround':
if self._round_type == 'adaround':
self._adaround_apply()
self._reset_activation_persistable()
......@@ -651,7 +644,6 @@ class PostTrainingQuantization(object):
float(np.max(np.abs(var_tensor[i]))))
self._quantized_threshold[var_name] = abs_max_value
_logger.info("MSE searching stage ...")
distribution = np.round if self._round_type == 'TiesToEven' else utils.round_c
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
var_tensor = var_tensor.flatten()
......@@ -664,9 +656,14 @@ class PostTrainingQuantization(object):
scale = s * abs_max_value
s += 0.02
bins = 2**(self._activation_bits - 1) - 1
quant_var = np.clip(distribution(var_tensor / scale * bins),
-bins - 1, bins)
quant_dequant_var = quant_var / bins * scale
if self._onnx_format:
quant_var = np.clip(distribution(var_tensor / scale * bins),
-bins - 1, bins)
quant_dequant_var = quant_var / bins * scale
else:
quant_dequant_var = np.round(
np.clip(var_tensor, 0.0, scale) / scale *
bins) / bins * scale
mse_loss = ((var_tensor - quant_dequant_var)**2).mean()
if mse_loss <= self._best_calibration_loss[var_name]:
self._best_calibration_loss[var_name] = mse_loss
......@@ -691,7 +688,6 @@ class PostTrainingQuantization(object):
float(np.max(np.abs(var_tensor[i]))))
self._quantized_threshold[var_name] = abs_max_value
_logger.info("EMD searching stage ...")
distribution = np.round if self._round_type == 'TiesToEven' else utils.round_c
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
var_tensor = var_tensor.flatten()
......@@ -704,9 +700,14 @@ class PostTrainingQuantization(object):
scale = s * abs_max_value
s += 0.02
bins = 2**(self._activation_bits - 1) - 1
quant_var = np.clip(distribution(var_tensor / scale * bins),
-bins - 1, bins)
quant_dequant_var = quant_var / bins * scale
if self._onnx_format:
quant_var = np.clip(distribution(var_tensor / scale * bins),
-bins - 1, bins)
quant_dequant_var = quant_var / bins * scale
else:
quant_dequant_var = np.round(
np.clip(var_tensor, 0.0, scale) / scale *
bins) / bins * scale
emd_loss = np.abs(
np.mean(var_tensor) - np.mean(quant_dequant_var)) + np.abs(
np.std(var_tensor) - np.std(quant_dequant_var))
......@@ -918,8 +919,7 @@ class PostTrainingQuantization(object):
activation_bits=self._activation_bits,
activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types,
round_type=self._round_type)
quantizable_op_type=major_quantizable_op_types)
else:
transform_pass = QuantizationTransformPassV2(
scope=self._scope,
......@@ -928,8 +928,7 @@ class PostTrainingQuantization(object):
activation_bits=self._activation_bits,
activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types,
round_type=self._round_type)
quantizable_op_type=major_quantizable_op_types)
for sub_graph in graph.all_sub_graphs():
# Insert fake_quant/fake_dequantize op must in test graph, so
......@@ -946,15 +945,13 @@ class PostTrainingQuantization(object):
add_quant_dequant_pass = AddQuantDequantPass(
scope=self._scope,
place=self._place,
quantizable_op_type=minor_quantizable_op_types,
round_type=self._round_type)
quantizable_op_type=minor_quantizable_op_types)
else:
add_quant_dequant_pass = AddQuantDequantPassV2(
scope=self._scope,
place=self._place,
quantizable_op_type=minor_quantizable_op_types,
is_full_quantized=self._is_full_quantize,
round_type=self._round_type)
is_full_quantized=self._is_full_quantize)
for sub_graph in graph.all_sub_graphs():
sub_graph._for_test = True
......@@ -979,7 +976,6 @@ class PostTrainingQuantization(object):
place=self._place,
bias_correction=self._bias_correction,
weight_bits=self._weight_bits,
weight_round_algo=self._weight_round_algo,
round_type=self._round_type,
activation_bits=self._activation_bits,
weight_quantize_type=self._weight_quantize_type,
......
......@@ -119,7 +119,6 @@ class QuantizationTransformPass(object):
moving_rate=0.9,
skip_pattern=['skip_quant'],
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
round_type='TiesToEven',
weight_quantize_func=None,
act_quantize_func=None,
weight_preprocess_func=None,
......@@ -157,10 +156,6 @@ class QuantizationTransformPass(object):
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
weight_quantize_func(function): Function that defines how to quantize weight.
Using this can quickly test if user's quantization method works or not.
In this function, user should both define quantization function and
......@@ -211,7 +206,6 @@ class QuantizationTransformPass(object):
self._weight_bits = weight_bits
self._activation_bits = activation_bits
self._skip_pattern = skip_pattern
self._round_type = round_type
self._weight_quantize_func = weight_quantize_func
self._act_quantize_func = act_quantize_func
self._weight_preprocess_func = weight_preprocess_func
......@@ -465,12 +459,10 @@ class QuantizationTransformPass(object):
_init_var_node(scale_var_node,
np.zeros(scale_var_node.shape(), dtype=data_type),
self._scope, self._place)
round_type = 0 if self._round_type == 'TiesToEven' else 1
quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max',
attrs={
'bit_length': quant_bits,
'round_type': round_type,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': var_node},
......@@ -525,11 +517,9 @@ class QuantizationTransformPass(object):
inputs['Iter'] = self._global_step
outputs['OutScales'] = scales_node
round_type = 0 if self._round_type == 'TiesToEven' else 1
attrs = {
'window_size': self._window_size,
'bit_length': quant_bits,
'round_type': round_type,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
}
......@@ -600,10 +590,8 @@ class QuantizationTransformPass(object):
outs['OutState'] = state_out_node
outs['OutAccum'] = accum_out_node
round_type = 0 if self._round_type == 'TiesToEven' else 1
attrs = {
'bit_length': quant_bits,
'round_type': round_type,
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
......@@ -650,12 +638,10 @@ class QuantizationTransformPass(object):
_init_var_node(scale_var_node,
np.zeros(scale_var_node.shape(), dtype=data_type),
self._scope, self._place)
round_type = 0 if self._round_type == 'TiesToEven' else 1
quant_op_node = graph.create_op_node(
op_type='fake_channel_wise_quantize_abs_max',
attrs={
'bit_length': quant_bits,
'round_type': round_type,
'quant_axis': quant_axis,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
......@@ -949,8 +935,7 @@ class QuantizationFreezePass(object):
bias_correction=False,
weight_bits=8,
activation_bits=8,
weight_round_algo='round',
round_type='TiesToEven',
round_type='round',
weight_quantize_type='abs_max',
quantizable_op_type=None):
"""
......@@ -968,14 +953,10 @@ class QuantizationFreezePass(object):
https://arxiv.org/abs/1810.05723.
weight_bits(int): quantization bit number for weights.
activation_bits(int): quantization bit number for activation.
weight_round_algo(str, optional): The method of converting the quantized weights
round_type(str, optional): The method of converting the quantized weights
value float->int. Currently supports ['round', 'adaround'] methods.
Default is `round`, which is rounding nearest to the integer.
'adaround' is refer to https://arxiv.org/abs/2004.10568.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
weight_quantize_type(str): quantization type for weights, support 'abs_max' and
'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
since weights are fixed once the model is well trained.
......@@ -991,7 +972,6 @@ class QuantizationFreezePass(object):
self._place = _get_paddle_place(place)
self._weight_bits = weight_bits
self._activation_bits = activation_bits
self._weight_round_algo = weight_round_algo
self._round_type = round_type
self._weight_quantize_type = weight_quantize_type
self._fake_quant_op_names = _fake_quant_op_list
......@@ -1039,7 +1019,7 @@ class QuantizationFreezePass(object):
scale_v = scale_v.tolist()
self._quant_var_scale_map[input_arg_name] = scale_v
# Quantize weight and restore
if self._weight_round_algo == 'round':
if self._round_type == 'round':
param_v = self._load_var(input_arg_name)
if any(
_check_grandchild_op_node(op_node, op)
......@@ -1049,7 +1029,8 @@ class QuantizationFreezePass(object):
quant_axis = 0
quantized_param_v = utils.quant_tensor(
param_v.copy(), scale_v, quant_axis,
self._weight_bits, self._round_type)
self._weight_bits)
quantized_param_v = np.round(quantized_param_v)
# Weight bias correction
if self._bias_correction == True:
quantized_param_v = utils.bias_correction_w(
......@@ -1058,6 +1039,7 @@ class QuantizationFreezePass(object):
scale_v,
quant_axis,
weight_bits=self._weight_bits)
quantized_param_v = np.round(quantized_param_v)
self._restore_var(input_arg_name, quantized_param_v)
self._remove_fake_quant_and_dequant_op(graph, op_node)
......@@ -1600,8 +1582,7 @@ class AddQuantDequantPass(object):
quant_bits=8,
skip_pattern=["skip_quant"],
quantizable_op_type=["elementwise_add", "pool2d"],
is_full_quantized=False,
round_type='TiesToEven'):
is_full_quantized=False):
"""
Constructor.
......@@ -1623,10 +1604,6 @@ class AddQuantDequantPass(object):
quantization to all supported quantizable op type. If set is_full_quantized
as False, only apply quantization to the op type according to the input
quantizable_op_type.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
"""
self._scope = scope
self._place = _get_paddle_place(place)
......@@ -1634,7 +1611,6 @@ class AddQuantDequantPass(object):
self._quant_bits = quant_bits
self._is_test = None
self._skip_pattern = skip_pattern
self._round_type = round_type
if is_full_quantized:
self._quantizable_op_type = utils._act_supported_quantizable_op_type
......@@ -1769,10 +1745,8 @@ class AddQuantDequantPass(object):
outs['OutState'] = state_out_node
outs['OutAccum'] = accum_out_node
round_type = 0 if self._round_type == 'TiesToEven' else 1
attrs = {
'bit_length': quant_bits,
'round_type': round_type,
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
......@@ -1812,10 +1786,6 @@ class InsertQuantizeLinear(object):
Default is -1.
channel_wise(bool, optional): Whether quantization with per channel or not. Default is False.
is_test(bool, optional): Whether quantization with training or not. Default is True.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
"""
def __init__(self,
......@@ -1824,15 +1794,13 @@ class InsertQuantizeLinear(object):
quant_bits=8,
quant_axis=-1,
channel_wise=False,
is_test=True,
round_type='TiesToEven'):
is_test=True):
self._place = place
self._scope = scope
self.quant_bits = quant_bits
self.quant_axis = quant_axis
self.channel_wise = channel_wise
self._is_test = is_test
self._round_type = round_type
def insert_quant_op(self, graph, var_node):
assert var_node.is_var(), '{} is not a var'.format(var_node.name())
......@@ -1875,12 +1843,7 @@ class InsertQuantizeLinear(object):
if zero_point_node is not None:
inputs["ZeroPoint"] = zero_point_node
round_type = 0 if self._round_type == 'TiesToEven' else 1
attrs = {
"quant_axis": self.quant_axis,
"bit_length": self.quant_bits,
"round_type": round_type
}
attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
outputs = {"Y": quant_var_node}
if not self._is_test:
attrs["is_test"] = self._is_test
......@@ -1985,7 +1948,6 @@ class QuantizationTransformPassV2(object):
moving_rate=0.9,
skip_pattern=['skip_quant'],
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
round_type='TiesToEven',
weight_quantize_func=None,
act_quantize_func=None,
weight_preprocess_func=None,
......@@ -2021,10 +1983,6 @@ class QuantizationTransformPassV2(object):
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
weight_quantize_func(function): Function that defines how to quantize weight.
Using this can quickly test if user's quantization method works or not.
In this function, user should both define quantization function and
......@@ -2074,7 +2032,6 @@ class QuantizationTransformPassV2(object):
self._weight_bits = weight_bits
self._activation_bits = activation_bits
self._skip_pattern = skip_pattern
self._round_type = round_type
self._weight_quantize_func = weight_quantize_func
self._act_quantize_func = act_quantize_func
self._weight_preprocess_func = weight_preprocess_func
......@@ -2198,8 +2155,7 @@ class QuantizationTransformPassV2(object):
quant_bits=quant_bits,
quant_axis=quant_axis,
channel_wise=channel_wise,
is_test=self._is_test,
round_type=self._round_type)
is_test=self._is_test)
quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
graph, var_node)
dequant_var_node = insert_quant_pass.insert_dequant_op(
......@@ -2307,8 +2263,7 @@ class AddQuantDequantPassV2(object):
quant_bits=8,
skip_pattern=["skip_quant"],
quantizable_op_type=["elementwise_add", "pool2d"],
is_full_quantized=False,
round_type='TiesToEven'):
is_full_quantized=False):
"""
Args:
scope(paddle.Scope): The scope is used to initialize these new parameters.
......@@ -2328,10 +2283,6 @@ class AddQuantDequantPassV2(object):
quantization to all supported quantizable op type. If set is_full_quantized
as False, only apply quantization to the op type according to the input
quantizable_op_type.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
Examples:
.. code-block:: python
......@@ -2354,7 +2305,6 @@ class AddQuantDequantPassV2(object):
self._quant_bits = quant_bits
self._is_test = None
self._skip_pattern = skip_pattern
self._round_type = round_type
if is_full_quantized:
self._quantizable_op_type = utils._act_supported_quantizable_op_type
......@@ -2427,8 +2377,7 @@ class AddQuantDequantPassV2(object):
quant_bits=self._quant_bits,
quant_axis=-1,
channel_wise=False,
is_test=self._is_test,
round_type=self._round_type)
is_test=self._is_test)
quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
graph, in_node)
dequant_var_node = insert_quant_pass.insert_dequant_op(
......@@ -2511,8 +2460,6 @@ class ReplaceFakeQuantDequantPass(object):
"quant_axis") else -1
bit_length = op.op().attr("bit_length") if op.op().has_attr(
"bit_length") else 8
round_type = op.op().attr("round_type") if op.op().has_attr(
"round_type") else 0
zero_point_node = None
quanted_node = x_node
......@@ -2534,8 +2481,7 @@ class ReplaceFakeQuantDequantPass(object):
quant_op_node = graph.create_op_node(op_type="quantize_linear",
attrs={
"quant_axis": quant_axis,
"bit_length": bit_length,
"round_type": round_type
"bit_length": bit_length
},
inputs={
"X": x_node,
......@@ -2654,11 +2600,11 @@ class QuantWeightPass(object):
param_v = self._load_var(x_node.name())
quant_axis = _op.op().attr("quant_axis")
bits_length = _op.op().attr("bit_length")
round_type = _op.op().attr("round_type") if _op.op().has_attr(
"round_type") else 0
quantized_param_v = utils.quant_tensor(param_v.copy(), scale_v,
quant_axis, bits_length,
round_type)
quantized_param_v = utils.quant_tensor(param_v.copy(),
scale_v,
quant_axis,
bits_length,
onnx_format=True)
if self._bias_correction == True:
quantized_param_v = utils.bias_correction_w(
param_v,
......
......@@ -321,39 +321,41 @@ def set_variable_data(scope, place, var_name, np_value):
tensor.set(np_value, place)
def round_c_single_element(val):
dtype = type(val)
if val >= 0:
return dtype(np.floor(val + 0.5))
return dtype(np.ceil(val - 0.5))
def quant_tensor(x, scale, quant_axis=0, weight_bits=8, onnx_format=False):
# symmetry quant
def _clip(x, scale):
x[x > scale] = scale
x[x < -scale] = -scale
return x
# rounding to nearest ties away from zero
round_c = np.vectorize(round_c_single_element)
def quant_tensor(x,
scale,
quant_axis=0,
weight_bits=8,
round_type='TiesToEven'):
assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
distribution = np.round if round_type == 'TiesToEven' else round_c
bnt = (1 << (weight_bits - 1)) - 1
if isinstance(scale, list):
for i, s in enumerate(scale):
if s == 0.0:
s = 1e-8
if quant_axis == 0:
x[i] = distribution(x[i] / s * bnt)
x[i] = np.clip(x[i], -bnt - 1, bnt)
if onnx_format:
x[i] = np.round(x[i] / s * bnt)
x[i] = np.clip(x[i], -bnt - 1, bnt)
else:
x[i] = _clip(x[i], s)
x[i] = x[i] / s * bnt
else:
x[:, i] = distribution(x[:, i] / s * bnt)
x[:, i] = np.clip(x[:, i], -bnt - 1, bnt)
if onnx_format:
x[:, i] = np.round(x[:, i] / s * bnt)
x[:, i] = np.clip(x[:, i], -bnt - 1, bnt)
else:
x[:, i] = _clip(x[:, i], s)
x[:, i] = x[:, i] / s * bnt
else:
scale = 1e-8 if scale == 0.0 else scale
x = distribution(x / scale * bnt)
x = np.clip(x, -bnt - 1, bnt)
if onnx_format:
x = np.round(x / scale * bnt)
x = np.clip(x, -bnt - 1, bnt)
else:
x = _clip(x, scale)
x = x / scale * bnt
return x
......
......@@ -165,7 +165,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_path,
data_path,
algo="KL",
weight_round_algo="round",
round_type="round",
quantizable_op_type=["conv2d"],
is_full_quantize=False,
is_use_cache_file=False,
......@@ -185,7 +185,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
batch_nums=batch_nums,
algo=algo,
quantizable_op_type=quantizable_op_type,
weight_round_algo=weight_round_algo,
round_type=round_type,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
onnx_format=onnx_format,
......@@ -201,7 +201,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
data_url,
data_md5,
algo,
weight_round_algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......@@ -224,7 +224,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start post training quantization for {0} on {1} samples ...".
format(model_name, quant_iterations))
self.generate_quantized_model(fp32_model_path, data_path, algo,
weight_round_algo, quantizable_op_type,
round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file,
is_optimize_model, quant_iterations,
onnx_format)
......@@ -255,7 +255,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
data_md5 = "add84c754e9b792fea1fbd728d134ab7"
algo = "avg"
weight_round_algo = "round"
round_type = "round"
quantizable_op_type = ["mul", "lstm"]
is_full_quantize = False
is_use_cache_file = False
......@@ -264,7 +264,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
infer_iterations = 100
quant_iterations = 10
self.run_test(model_name, model_url, model_md5, data_name, data_url,
data_md5, algo, weight_round_algo, quantizable_op_type,
data_md5, algo, round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold, infer_iterations, quant_iterations)
......@@ -279,7 +279,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
data_md5 = "add84c754e9b792fea1fbd728d134ab7"
algo = "avg"
weight_round_algo = "round"
round_type = "round"
quantizable_op_type = ["mul", "lstm"]
is_full_quantize = False
is_use_cache_file = False
......@@ -295,7 +295,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
data_url,
data_md5,
algo,
weight_round_algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......
......@@ -108,7 +108,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
def generate_quantized_model(self,
model_path,
algo="KL",
weight_round_algo="round",
round_type="round",
quantizable_op_type=["conv2d"],
is_full_quantize=False,
is_use_cache_file=False,
......@@ -130,7 +130,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
batch_nums=batch_nums,
algo=algo,
quantizable_op_type=quantizable_op_type,
weight_round_algo=weight_round_algo,
round_type=round_type,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
bias_correction=bias_correction,
......@@ -145,7 +145,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
data_url,
data_md5,
algo,
weight_round_algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......@@ -169,11 +169,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start INT8 post training quantization for {0} on {1} images ...".
format(model_name, quant_iterations * batch_size))
self.generate_quantized_model(origin_model_path, algo,
weight_round_algo, quantizable_op_type,
is_full_quantize, is_use_cache_file,
is_optimize_model, batch_size,
quant_iterations, onnx_format,
self.generate_quantized_model(origin_model_path, algo, round_type,
quantizable_op_type, is_full_quantize,
is_use_cache_file, is_optimize_model,
batch_size, quant_iterations, onnx_format,
skip_tensor_list, bias_correction)
print("Start INT8 inference for {0} on {1} images ...".format(
......@@ -204,7 +203,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "KL"
weight_round_algo = "round"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -213,7 +212,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, weight_round_algo,
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)
......@@ -226,7 +225,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "hist"
weight_round_algo = "round"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -235,7 +234,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, weight_round_algo,
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)
......@@ -248,7 +247,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "mse"
weight_round_algo = "round"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -257,7 +256,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, weight_round_algo,
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)
......@@ -270,7 +269,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "emd"
weight_round_algo = "round"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -279,7 +278,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, weight_round_algo,
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)
......@@ -292,7 +291,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "avg"
weight_round_algo = "round"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -301,7 +300,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, weight_round_algo,
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)
......@@ -314,7 +313,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "abs_max"
weight_round_algo = "round"
round_type = "round"
quantizable_op_type = ["conv2d", "mul"]
is_full_quantize = True
is_use_cache_file = False
......@@ -323,7 +322,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 10
self.run_test(model_name, data_url, data_md5, algo, weight_round_algo,
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)
......@@ -336,7 +335,7 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "mse"
weight_round_algo = "adaround"
round_type = "adaround"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -350,7 +349,7 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
data_url,
data_md5,
algo,
weight_round_algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......@@ -369,7 +368,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "KL"
weight_round_algo = "adaround"
round_type = "adaround"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -378,7 +377,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, weight_round_algo,
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)
......@@ -391,7 +390,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "mse"
weight_round_algo = "round"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -405,7 +404,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
data_url,
data_md5,
algo,
weight_round_algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......@@ -425,7 +424,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "mse"
weight_round_algo = "round"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = True
is_use_cache_file = False
......@@ -439,7 +438,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
data_url,
data_md5,
algo,
weight_round_algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......@@ -458,7 +457,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "avg"
weight_round_algo = "round"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -472,7 +471,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
data_url,
data_md5,
algo,
weight_round_algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......
......@@ -242,7 +242,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_path,
quantizable_op_type,
algo="KL",
weight_round_algo="round",
round_type="round",
is_full_quantize=False,
is_use_cache_file=False,
is_optimize_model=False,
......@@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_dir=model_path,
algo=algo,
quantizable_op_type=quantizable_op_type,
weight_round_algo=weight_round_algo,
round_type=round_type,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
onnx_format=onnx_format,
......@@ -275,7 +275,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
def run_test(self,
model,
algo,
weight_round_algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
......@@ -299,10 +299,9 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start INT8 post training quantization for {0} on {1} images ...".
format(model, sample_iterations * batch_size))
self.generate_quantized_model(model_cache_folder + "/model",
quantizable_op_type, algo,
weight_round_algo, is_full_quantize,
is_use_cache_file, is_optimize_model,
onnx_format)
quantizable_op_type, algo, round_type,
is_full_quantize, is_use_cache_file,
is_optimize_model, onnx_format)
print("Start INT8 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size))
......@@ -330,7 +329,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_kl_mobilenetv1(self):
model = "MobileNet-V1"
algo = "KL"
weight_round_algo = "round"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
......@@ -345,7 +344,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
self.run_test(model, algo, weight_round_algo, data_urls, data_md5s,
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
......@@ -355,7 +354,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_avg_mobilenetv1(self):
model = "MobileNet-V1"
algo = "avg"
weight_round_algo = "round"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
......@@ -369,7 +368,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
self.run_test(model, algo, weight_round_algo, data_urls, data_md5s,
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
......@@ -379,7 +378,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_hist_mobilenetv1(self):
model = "MobileNet-V1"
algo = "hist"
weight_round_algo = "round"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
......@@ -393,7 +392,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.03
self.run_test(model, algo, weight_round_algo, data_urls, data_md5s,
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
......@@ -403,7 +402,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_abs_max_mobilenetv1(self):
model = "MobileNet-V1"
algo = "abs_max"
weight_round_algo = "round"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
......@@ -417,7 +416,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
is_optimize_model = False
# The accuracy diff of post-training quantization (abs_max) maybe bigger
diff_threshold = 0.05
self.run_test(model, algo, weight_round_algo, data_urls, data_md5s,
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
......@@ -427,7 +426,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_onnx_format_mobilenetv1(self):
model = "MobileNet-V1"
algo = "avg"
weight_round_algo = "round"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
......@@ -444,7 +443,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
diff_threshold = 0.05
self.run_test(model,
algo,
weight_round_algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
......
......@@ -25,7 +25,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
def test_post_training_resnet50(self):
model = "ResNet-50"
algo = "min_max"
weight_round_algo = "round"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
......@@ -35,7 +35,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
is_use_cache_file = False
is_optimize_model = False
diff_threshold = 0.025
self.run_test(model, algo, weight_round_algo, data_urls, data_md5s,
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
......@@ -45,7 +45,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
def test_post_training_resnet50(self):
model = "ResNet-50"
algo = "min_max"
weight_round_algo = "round"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
......@@ -58,7 +58,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
onnx_format = True
self.run_test(model,
algo,
weight_round_algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
......
......@@ -49,7 +49,7 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
dtype,
input_shape,
distribution,
round_type='TiesToEven'):
round_type='TiesAwayFromZero'):
input_data = distribution(input_shape).astype(dtype)
compute_type = get_compute_type(dtype)
scale = np.max(np.abs(input_data))
......@@ -58,12 +58,12 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
if round_type == 'TiesToEven':
round_out = np.round(
input_data.astype(compute_type) * inv_scale * bnt)
output_data = np.clip(round_out, -bnt - 1, bnt)
self.attrs['round_type'] = 0
else:
round_out = round_c(
output_data = round_c(
input_data.astype(compute_type) * inv_scale * bnt)
self.attrs['round_type'] = 1
output_data = np.clip(round_out, -bnt - 1, bnt)
self.inputs = {'X': input_data}
self.outputs = {'Out': output_data, 'OutScale': scale}
self.dtype = dtype
......@@ -75,7 +75,7 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
def test_fake_quantize_abs_max_round1(self):
self._fake_quantize_abs_max(np.float32, (124, 240),
np.random.random,
round_type='TiesAwayFromZero')
round_type='TiesToEven')
def test_fake_quantize_abs_max_float16(self):
self._fake_quantize_abs_max(np.float16, (124, 240), np.random.random)
......@@ -110,12 +110,12 @@ class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest):
if round_type == 'TiesToEven':
round_out = np.round(
input_data.astype(compute_type) / scale_broadcast * bnt)
output_data = np.clip(round_out, -bnt - 1, bnt)
self.attrs['round_type'] = 0
else:
round_out = round_c(
input_data.astype(compute_type) / scale_broadcast * bnt)
output_data = round_c(bnt * input_data.astype(compute_type) /
scale_broadcast)
self.attrs['round_type'] = 1
output_data = np.clip(round_out, -bnt - 1, bnt)
if quant_axis == 1:
scale_broadcast = np.transpose(scale_broadcast,
(1, ) + compute_axis)
......@@ -169,11 +169,15 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
round_out = np.round(
input_data.astype(compute_type) / out_scale[0] * bnt)
self.attrs['round_type'] = 0
output_data = np.clip(round_out, -bnt - 1, bnt)
else:
round_out = round_c(
input_data.astype(compute_type) / out_scale[0] * bnt)
if is_test:
clip_data = np.clip(input_data, -in_scale, in_scale)
else:
clip_data = input_data
output_data = round_c(
clip_data.astype(compute_type) / out_scale[0] * bnt)
self.attrs['round_type'] = 1
output_data = np.clip(round_out, -bnt - 1, bnt)
self.inputs = {
'X': input_data,
'Iter': np.zeros(1).astype(np.int64),
......@@ -250,7 +254,7 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
distribution,
dequantize=False,
with_gradient=False,
round_type='TiesToEven'):
round_type='TiesAwayFromZero'):
input_data = distribution(input_shape).astype(dtype)
compute_type = get_compute_type(dtype)
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
......@@ -267,12 +271,12 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
if round_type == 'TiesToEven':
round_out = np.round(
input_data.astype(compute_type) / out_scale * bnt)
quant_data = np.clip(round_out, -bnt - 1, bnt)
self.attrs['round_type'] = 0
else:
round_out = round_c(
quant_data = round_c(
input_data.astype(compute_type) / out_scale * bnt)
self.attrs['round_type'] = 1
quant_data = np.clip(round_out, -bnt - 1, bnt)
if dequantize:
output_data = (quant_data * out_scale / bnt).astype(dtype)
self.op_type = 'fake_quantize_dequantize_moving_average_abs_max'
......@@ -307,10 +311,9 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
np.random.random)
def test_fake_quantize_moving_average_abs_max_round1(self):
self._fake_quantize_moving_average_abs_max(
np.float32, (8, 16, 7, 7),
np.random.random,
round_type='TiesAwayFromZero')
self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7),
np.random.random,
round_type='TiesToEven')
def test_fake_quantize_dequantize_moving_average_abs_max(self):
self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7),
......@@ -329,17 +332,17 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
dtype,
input_shape,
distribution,
round_type='TiesToEven'):
round_type='TiesAwayFromZero'):
input_data = distribution(input_shape).astype(dtype)
scale = np.max(np.abs(input_data)).astype(dtype)
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
if round_type == 'TiesToEven':
round_out = np.round(input_data / scale * bnt)
output_data = np.clip(round_out, -bnt - 1, bnt) * scale / bnt
self.attrs['round_type'] = 0
else:
round_out = round_c(input_data / scale * bnt)
output_data = round_c(input_data / scale * bnt) * scale / bnt
self.attrs['round_type'] = 1
output_data = np.clip(round_out, -bnt - 1, bnt) * scale / bnt
self.inputs = {'X': input_data}
self.outputs = {
'Out': output_data,
......@@ -357,7 +360,7 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
def test_fake_quantize_dequantize_abs_max_round1(self):
self._fake_quantize_dequantize_abs_max(np.float32, (124, 240),
np.random.random,
round_type='TiesAwayFromZero')
round_type='TiesToEven')
class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
......@@ -382,11 +385,13 @@ class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True)
if round_type == 'TiesToEven':
round_out = np.round(bnt * output_data / scale_broadcast)
output_data = np.clip(round_out, -bnt - 1,
bnt) * scale_broadcast / bnt
self.attrs['round_type'] = 0
else:
round_out = round_c(bnt * output_data / scale_broadcast)
output_data = round_c(
bnt * output_data / scale_broadcast) * scale_broadcast / bnt
self.attrs['round_type'] = 1
output_data = np.clip(round_out, -bnt - 1, bnt) * scale_broadcast / bnt
if quant_axis == 1:
scale_broadcast = np.transpose(scale_broadcast,
(1, ) + compute_axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册