Unverified commit 72973d5a authored by zhouzj, committed by GitHub

[clean fluid api] Move fluid/contrib/slim and remove fluid api. (#48717)

Parent a186e60d
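This commit relocates the static-graph quantization passes from paddle.fluid.contrib.slim.quantization to paddle.static.quantization, and the imperative (dygraph) QAT/PTQ classes to paddle.quantization. A minimal before/after sketch of the import migration, based on the changes below:

# Before (fluid API, removed by this commit)
from paddle.fluid.contrib.slim.quantization import QuantWeightPass
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

# After
from paddle.static.quantization import QuantWeightPass
from paddle.quantization import ImperativeQuantAware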
......@@ -119,7 +119,7 @@ if(WITH_TESTING)
add_subdirectory(paddle/tests)
add_subdirectory(paddle/fluid/tests)
add_subdirectory(paddle/fluid/contrib/tests)
add_subdirectory(paddle/fluid/contrib/slim/tests)
add_subdirectory(paddle/static/quantization/tests)
endif()
if(NOT WITH_SETUP_INSTALL)
......
......@@ -1617,9 +1617,7 @@ class Engine:
fetch_vars = self._fetch_vars["predict"]['outputs']
dist_main_prog = self._dist_main_progs["predict"][self._cur_rank]
if self._strategy.qat.enable and self._strategy.qat.onnx_format:
from paddle.fluid.contrib.slim.quantization import (
QuantWeightPass,
)
from paddle.static.quantization import QuantWeightPass
self._logger.info("export quantized model.")
self._logger.info(
......
......@@ -18,14 +18,14 @@ import numpy as np
import paddle
from paddle.fluid import core, framework
from paddle.fluid.contrib.slim.quantization import (
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.static.quantization import (
AddQuantDequantForInferencePass,
AddQuantDequantPassV2,
OutScaleForTrainingPass,
QuantizationTransformPassV2,
utils,
)
from paddle.fluid.dygraph.parallel import ParallelEnv
from ..auto_parallel.converter import Converter
from ..auto_parallel.dist_attribute import (
......
......@@ -18,9 +18,6 @@ from . import memory_usage_calc
from .memory_usage_calc import *
from . import op_frequence
from .op_frequence import *
from . import quantize
from .quantize import *
from . import slim
from . import extend_optimizer
from .extend_optimizer import *
from . import model_stat
......@@ -36,7 +33,6 @@ __all__ = []
__all__ += memory_usage_calc.__all__
__all__ += op_frequence.__all__
__all__ += quantize.__all__
__all__ += extend_optimizer.__all__
__all__ += ['mixed_precision']
__all__ += layers.__all__
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import numpy as np
from paddle.fluid.framework import (
default_main_program,
default_startup_program,
program_guard,
)
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import unique_name
from paddle.fluid import core
from paddle.fluid.initializer import Constant
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layers.nn import autoincreased_step_counter
from paddle.fluid.framework import Variable
from paddle.fluid.executor import global_scope
__all__ = ['QuantizeTranspiler']
_QUANTIZABLE_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul']
def _quantized_var_name(var_name):
"""
Return quantized variable name for the input `var_name`.
"""
return "%s.quantized" % (var_name)
def _dequantized_var_name(var_name):
"""
Return dequantized variable name for the input `var_name`.
"""
return "%s.dequantized" % (var_name)
def _quantized_scale_name(var_name):
"""
Return the scale variable name for the input `var_name`.
"""
return "%s.scale" % (var_name)
def _original_var_name(var_name):
"""
Return the original variable name.
"""
if var_name.endswith('.quantized.dequantized'):
return var_name[: -len('.quantized.dequantized')]
if var_name.endswith('.quantized'):
return var_name[: -len('.quantized')]
if var_name.endswith('.dequantized'):
return var_name[: -len('.dequantized')]
if var_name.endswith('.scale'):
return var_name[: -len('.scale')]
else:
return var_name
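# Illustrative note (not part of the original file): `_original_var_name`
# strips the transpiler-added suffixes, e.g.
#   _original_var_name("conv2d_0.w_0.quantized.dequantized") -> "conv2d_0.w_0"
#   _original_var_name("conv2d_0.w_0.scale") -> "conv2d_0.w_0"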
def _is_float(v):
return isinstance(v, float) or isinstance(v, np.float32)
def quant(x, scale, num_bits):
y = np.round(x / scale * ((1 << (num_bits - 1)) - 1))
return y
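# Illustrative note (not part of the original file): `quant` maps a float
# tensor onto symmetric integer levels. With num_bits=8 the scaling factor
# is (1 << 7) - 1 = 127, so for example:
#   quant(np.array([0.25]), scale=1.0, num_bits=8)  ->  array([32.])
#   (0.25 / 1.0 * 127 = 31.75, rounded to 32)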
class QuantizeTranspiler:
def __init__(
self,
weight_bits=8,
activation_bits=8,
activation_quantize_type='abs_max',
weight_quantize_type='abs_max',
window_size=10000,
moving_rate=0.9,
):
"""
Convert and rewrite the fluid Program according to weight and
activation quantization type.
Args:
weight_bits (int): quantization bit number for weights,
the bias is not quantized.
activation_bits (int): quantization bit number for activation.
activation_quantize_type (str): quantization type for activations;
'abs_max', 'range_abs_max' and 'moving_average_abs_max' are supported.
With 'abs_max', the quantization scale is computed dynamically at each
step during both training and testing. With 'range_abs_max', a static
quantization scale is computed during training and used for inference.
weight_quantize_type (str): quantization type for weights; 'abs_max'
is supported. 'range_abs_max' is usually not used for weights, since
weights are fixed once the model is well trained.
window_size (int): the window size for 'range_abs_max' quantization.
moving_rate (float): the decay coefficient of the moving average,
used by 'moving_average_abs_max' quantization.
Examples:
.. code-block:: python
# The original program is rewritten in place. If you don't want to
# change it, clone it first:
# quantize_program = program.clone()
from paddle.fluid.contrib.quantize.quantize_transpiler import QuantizeTranspiler
t = QuantizeTranspiler()
t.training_transpile(quantize_program)
"""
self.weight_bits = weight_bits
self.activation_bits = activation_bits
quant_type = ['abs_max', 'range_abs_max', 'moving_average_abs_max']
if weight_quantize_type not in quant_type:
raise ValueError(
"Unknown weight_quantize_type: '%s'. It can only be "
"'abs_max', 'range_abs_max' or 'moving_average_abs_max'."
% str(weight_quantize_type)
)
if activation_quantize_type not in quant_type:
raise ValueError(
"Unknown activation_quantize_type: '%s'. It can only be "
"'abs_max', 'range_abs_max' or 'moving_average_abs_max'."
% str(activation_quantize_type)
)
self.weight_quantize_type = weight_quantize_type
self.activation_quantize_type = activation_quantize_type
self.window_size = window_size
self.moving_rate = moving_rate
self.helper = LayerHelper(self.__class__.__name__)
self.fake_quant_op_types = [
'fake_quantize_abs_max',
'fake_quantize_range_abs_max',
'fake_quantize_moving_average_abs_max',
]
self.fake_dequant_op_types = ['fake_dequantize_max_abs']
self.is_test = None
self.global_step = None
def training_transpile(self, program=None, startup_program=None):
"""Rewrites a training input program in place for simulated
quantization. Insert fake quantization and de-quantization ops into
program to simulate the error introduced by quantization. And change
the gradient ops' input by using the faked quantization weights and
activation. Since the program is transformed in place, the graph
connection will change.
Args:
program (Program): the input program to be transpile.
"""
self.is_test = False
program = default_main_program() if program is None else program
startup_program = (
default_startup_program()
if startup_program is None
else startup_program
)
# mark the variables which have been quantized and dequantized
dequanted_vars = [
collections.OrderedDict() for _ in range(len(program.blocks))
]
grad_op_types = ['%s_grad' % (type) for type in _QUANTIZABLE_OP_TYPES]
params = [p.name for p in program.global_block().iter_parameters()]
def _transpile_forward(block, op):
idx = block.ops.index(op)
block_id = block.idx
# insert quant op and dequant op
for name in op.input_arg_names:
# if the input is shared between ops
if name in dequanted_vars[block_id]:
dequant_var = dequanted_vars[block_id][name]
else:
var = block.var(name)
quant_bits = (
self.weight_bits
if var.name in params
else self.activation_bits
)
quant_type = (
self.weight_quantize_type
if var.name in params
else self.activation_quantize_type
)
quant_var, scale_var = self._insert_quant_op(
block, idx, var, quant_bits, quant_type
)
dequant_var = self._insert_dequant_op(
block, idx + 1, quant_var, scale_var, quant_bits
)
dequanted_vars[block_id][name] = dequant_var
# rename the forward op inputs
op._rename_input(name, dequant_var.name)
def _transpile_backward(block, op):
block_id = block.idx
no_dequanted_input_vars = True
for name in op.input_arg_names:
if name in dequanted_vars[block_id]:
dequant_var = dequanted_vars[block_id][name]
op._rename_input(name, dequant_var.name)
no_dequanted_input_vars = False
if no_dequanted_input_vars:
raise ValueError(
"There is no dequanted inputs for op %s." % (op.type)
)
with program_guard(program, startup_program):
self._create_global_step()
for block in program.blocks:
ops = list(block.ops)
block_id = block.idx
for op in ops:
# rewrite the forward ProgramDesc
if op.type in _QUANTIZABLE_OP_TYPES:
_transpile_forward(block, op)
# rename the backward op inputs
if op.type in grad_op_types:
_transpile_backward(block, op)
def _create_global_step(self):
if (
self.weight_quantize_type == 'range_abs_max'
or self.activation_quantize_type == 'range_abs_max'
):
self.global_step = autoincreased_step_counter()
def freeze_program(self, program, place, scope=None):
"""Freeze input training program for inference.
Args:
program (Program): the input program to be transpile.
"""
self.is_test = True
scope = global_scope() if scope is None else scope
program = default_main_program() if program is None else program
persistable_vars = [
v.name
for v in filter(lambda var: var.persistable, program.list_vars())
]
op_in_rename_map = [
collections.OrderedDict() for _ in range(len(program.blocks))
]
op_out_rename_map = [
collections.OrderedDict() for _ in range(len(program.blocks))
]
var_scale_map = [
collections.OrderedDict() for _ in range(len(program.blocks))
]
def _remove_fake_quant_and_dequant_op(block, op):
idx = block.ops.index(op)
block_id = block.idx
k = op.output('Out')[0]
v = op.input('X')[0]
if v not in op_in_rename_map[block_id]:
op_in_rename_map[block_id][k] = v
else:
op_in_rename_map[block_id][k] = op_in_rename_map[block_id][v]
block._remove_op(idx)
def _insert_post_dequant_op(block, op):
idx = block.ops.index(op)
block_id = block.idx
max_range = None
scale_var = None
for name in op.input_arg_names:
# rename the op's input to the input of the last op, which has been removed
if name in op_in_rename_map[block_id]:
op._rename_input(name, op_in_rename_map[block_id][name])
scale_v = var_scale_map[block_id][_original_var_name(name)]
if _original_var_name(name) in persistable_vars:
param_range = (1 << (self.weight_bits - 1)) - 1
act_range = (1 << (self.activation_bits - 1)) - 1
assert _is_float(scale_v)
max_range = param_range * act_range / scale_v
else:
assert isinstance(scale_v, Variable)
scale_var = scale_v
if len(op.output_arg_names) != 1:
raise ValueError(
"Only support one output, but op %s has"
" more than one output." % (op.type)
)
out_var = block.var(op.output_arg_names[0])
dequant_var = block.create_var(
name=_dequantized_var_name(out_var.name),
type=out_var.type,
shape=out_var.shape,
dtype=out_var.dtype,
)
# insert fake_dequantize_op
dequant_op = block._insert_op(
idx + 1,
type="fake_dequantize_max_abs",
attrs={'max_range': float(max_range)},
inputs={"X": out_var, 'Scale': scale_var},
outputs={"Out": dequant_var},
)
op_out_rename_map[block_id][out_var.name] = dequant_var.name
return dequant_var
def _load_var(name):
return np.array(scope.find_var(name).get_tensor())
def _restore_var(name, arr):
t = scope.find_var(name).get_tensor()
t.set(arr, place)
for block in program.blocks:
ops = list(block.ops)
block_id = block.idx
for op in ops:
op_type = op.type
# insert dequant_op after fc/conv; the inputs of the ops that follow
# fc/conv need to be renamed to the dequant_op's output
for name in op.input_arg_names:
if name in op_out_rename_map[block_id]:
op._rename_input(
name, op_out_rename_map[block_id][name]
)
if op_type in self.fake_quant_op_types:
in_arg_name = op.input('X')[0]
if in_arg_name in persistable_vars:
if self.weight_quantize_type == 'abs_max':
param = _load_var(in_arg_name)
scale_v = np.max(np.abs(param))
else:
scale_v = _load_var(op.output('OutScale')[0])
var_scale_map[block_id][in_arg_name] = scale_v
else:
scale_v = block.var(op.output('OutScale')[0])
var_scale_map[block_id][in_arg_name] = scale_v
if in_arg_name in persistable_vars:
_remove_fake_quant_and_dequant_op(block, op)
# quantize weight and restore
param_t = _load_var(in_arg_name)
param_q_t = quant(param_t, scale_v, self.weight_bits)
_restore_var(in_arg_name, param_q_t)
if op_type in self.fake_dequant_op_types:
_remove_fake_quant_and_dequant_op(block, op)
if op_type in _QUANTIZABLE_OP_TYPES:
dequant_var = _insert_post_dequant_op(block, op)
# remove the unused var in ProgramDesc
self._remove_unused_var(program)
# program = program.clone()
def convert_to_int8(self, program, place, scope=None):
scope = global_scope() if scope is None else scope
program = default_main_program() if program is None else program
def _load_var(name):
return np.array(scope.find_var(name).get_tensor())
global_block = program.global_block()
def convert_to_int8(var):
int8_var_name = var.name + ".int8"
int8_var = global_block.create_parameter(
name=int8_var_name.encode('ascii'),
type=var.type,
dtype=core.VarDesc.VarType.INT8,
shape=var.shape,
)
tensor = _load_var(var.name)
scope.var(int8_var_name)
int8_tensor = scope.find_var(int8_var_name).get_tensor()
int8_tensor.set(tensor.astype(np.int8), place)
return int8_var
input_map = {}
for block in program.blocks:
for op in list(block.ops):
if op.type in _QUANTIZABLE_OP_TYPES:
for name in op.input_arg_names:
var = block.var(name)
if var.persistable:
if name not in input_map:
int8_var = convert_to_int8(var)
input_map[name] = int8_var.name
op._rename_input(name, input_map[name])
self._remove_unused_var(program)
def _remove_unused_var(self, program):
all_remove_vars = []
for block in program.blocks:
args = []
for op in block.ops:
args += op.input_arg_names
args += op.output_arg_names
args = list(set(args))  # input/output args of all remaining ops
var_names = block.vars.keys()  # all vars in the block
sub_block_remove_vars = []
for var in var_names:
if var not in args:
sub_block_remove_vars.append(var)
all_remove_vars.append(sub_block_remove_vars)
remove_vars = [list(set(v)) for v in all_remove_vars]
for i, block in enumerate(program.blocks):
for v in remove_vars[i]:
block._remove_var(v)
def _insert_quant_abs_max_op(self, block, idx, var, quant_bits):
"""Insert fake_quantize_abs_max op."""
quant_var = block.create_var(
name=_quantized_var_name(var.name),
type=var.type,
shape=var.shape,
dtype=var.dtype,
)
scale = block.create_var(
name=_quantized_scale_name(var.name),
type=var.type,
shape=var.shape,
dtype=var.dtype,
)
quant_op = block._insert_op(
idx,
type='fake_quantize_abs_max',
attrs={'bit_length': quant_bits},
inputs={'X': var},
outputs={'Out': quant_var, 'OutScale': scale},
)
return quant_var, scale
def _insert_quant_range_abs_max_op(self, block, idx, var, quant_bits):
"""Insert fake_quantize_range_abs_max"""
quant_var = block.create_var(
name=_quantized_var_name(var.name),
type=var.type,
shape=var.shape,
dtype=var.dtype,
)
scale = self.helper.create_parameter(
attr=ParamAttr(
name=_quantized_scale_name(var.name),
initializer=Constant(0.001),
trainable=False,
),
shape=[1],
dtype=var.dtype,
)
scale.stop_gradient = True
ins = {'X': var, 'InScale': scale}
outs = {'Out': quant_var, 'OutScale': scale}
if not self.is_test:
# A buffer that records the quantization scales of the most recent `window_size` steps
scales = self.helper.create_global_variable(
name=unique_name.generate('scales'),
persistable=True,
dtype=var.dtype,
shape=[self.window_size],
)
self.helper.set_variable_initializer(
scales, initializer=Constant(value=0)
)
ins['Iter'] = self.global_step
outs['OutScales'] = scales
attrs = {
'window_size': self.window_size,
'bit_length': quant_bits,
'is_test': self.is_test,
}
quant_op = block._insert_op(
idx,
type='fake_quantize_range_abs_max',
attrs=attrs,
inputs=ins,
outputs=outs,
)
return quant_var, scale
def _insert_quant_moving_average_abs_max_op(
self, block, idx, var, quant_bits
):
"""Insert fake_quantize_moving_average_abs_max"""
quant_var = block.create_var(
name=_quantized_var_name(var.name),
type=var.type,
shape=var.shape,
dtype=var.dtype,
)
state = self.helper.create_global_variable(
name=unique_name.generate('state'),
persistable=True,
dtype=var.dtype,
shape=[1],
)
self.helper.set_variable_initializer(
state, initializer=Constant(value=1)
)
accum = self.helper.create_global_variable(
name=unique_name.generate('accum'),
persistable=True,
dtype=var.dtype,
shape=[1],
)
self.helper.set_variable_initializer(
accum, initializer=Constant(value=1)
)
scale = self.helper.create_parameter(
attr=ParamAttr(
name=_quantized_scale_name(var.name),
initializer=Constant(0.001),
trainable=False,
),
shape=[1],
dtype=var.dtype,
)
scale.stop_gradient = True
ins = {'X': var, 'InScale': scale}
outs = {'Out': quant_var, 'OutScale': scale}
if not self.is_test:
ins['InState'] = state
ins['InAccum'] = accum
outs['OutState'] = state
outs['OutAccum'] = accum
attrs = {
'bit_length': quant_bits,
'moving_rate': self.moving_rate,
'is_test': self.is_test,
}
quant_op = block._insert_op(
idx,
type='fake_quantize_moving_average_abs_max',
attrs=attrs,
inputs=ins,
outputs=outs,
)
return quant_var, scale
def _insert_quant_op(self, block, idx, var, quant_bits, quant_type):
"""
Insert fake_quantize_op
"""
if quant_type == 'abs_max':
return self._insert_quant_abs_max_op(block, idx, var, quant_bits)
elif quant_type == 'range_abs_max':
return self._insert_quant_range_abs_max_op(
block, idx, var, quant_bits
)
elif quant_type == 'moving_average_abs_max':
return self._insert_quant_moving_average_abs_max_op(
block, idx, var, quant_bits
)
def _insert_dequant_op(self, block, idx, var, scale, quant_bits):
"""
Insert fake_dequantize_op
"""
dequant_var = block.create_var(
name=_dequantized_var_name(var.name),
type=var.type,
shape=var.shape,
dtype=var.dtype,
)
# insert fake_dequantize_op
max_range = (1 << (quant_bits - 1)) - 1
dequant_op = block._insert_op(
idx,
type="fake_dequantize_max_abs",
attrs={'max_range': float(max_range)},
inputs={"X": var, 'Scale': scale},
outputs={"Out": dequant_var},
)
return dequant_var
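For reference, a minimal usage sketch of the transpiler defined above, mirroring the test near the end of this diff (main, startup, test_program and place are assumed to be existing fluid Programs and a device Place):

quant_transpiler = QuantizeTranspiler(activation_quantize_type='range_abs_max')
quant_transpiler.training_transpile(main, startup)
quant_transpiler.training_transpile(test_program, startup)
# ... run training, then freeze the test program for inference
quant_transpiler.freeze_program(test_program, place)
# optionally store the quantized weights as int8 parameters
quant_transpiler.convert_to_int8(test_program, place)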
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import quantization_pass
from .quantization_pass import *
from . import quant_int8_mkldnn_pass
from .quant_int8_mkldnn_pass import *
from . import quant2_int8_mkldnn_pass
from .quant2_int8_mkldnn_pass import *
from . import post_training_quantization
from .post_training_quantization import *
from . import imperative
from .imperative import *
__all__ = []
__all__ += quantization_pass.__all__
__all__ += quant_int8_mkldnn_pass.__all__
__all__ += quant2_int8_mkldnn_pass.__all__
__all__ += post_training_quantization.__all__
__all__ += imperative.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
import numpy as np
from .... import core
from ....framework import Program, Operator, Variable, program_guard
from ....executor import global_scope
from .... import unique_name
from ....layer_helper import LayerHelper
from ....param_attr import ParamAttr
from ....initializer import Constant
from ....log_helper import get_logger
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
def find_next_ops(block, var_name):
"""
Find all ops that take the given variable as an input.
"""
res_ops = []
for op in block.ops:
if var_name in op.input_arg_names:
res_ops.append(op)
return res_ops
def load_variable_data(scope, var_name):
'''
Load variable value from scope
'''
var_node = scope.find_var(var_name)
assert var_node is not None, "Cannot find " + var_name + " in scope."
return np.array(var_node.get_tensor())
class QuantizeTranspilerV2:
def __init__(
self,
weight_bits=8,
activation_bits=8,
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max',
quantizable_op_type=[
'conv2d',
'depthwise_conv2d',
'mul',
],
skip_pattern=['skip_quant'],
):
"""
Apply fake quantization to the quantizable ops.
Args:
weight_bits(int): the number of bits used to quantize weights.
activation_bits(int): the number of bits used to quantize activations.
weight_quantize_type(str): the quantization type for weights.
Only 'abs_max' and 'channel_wise_abs_max' are supported.
activation_quantize_type(str): the quantization type for activations.
Only 'abs_max' and 'moving_average_abs_max' are supported.
quantizable_op_type(list[str]): the op types to be quantized.
skip_pattern(str|list): the user-defined quantization skip pattern(s),
which are expected to appear in an op's name scope. When a skip pattern
is detected in an op's name scope, that op will not be quantized.
"""
self._weight_bits = weight_bits
self._activation_bits = activation_bits
assert activation_quantize_type in [
"abs_max",
"moving_average_abs_max",
], (
"activation_quantize_type should be abs_max "
"or moving_average_abs_max for now."
)
assert weight_quantize_type in [
"abs_max",
"channel_wise_abs_max",
], "weight_quantize_type should be abs_max or channel_wise_abs_max."
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
for op_type in quantizable_op_type:
assert op_type in [
'conv2d',
'depthwise_conv2d',
'mul',
], "Quantize op should be ['conv2d', 'depthwise_conv2d', 'mul']"
self._quantizable_ops = quantizable_op_type
self._quantizable_grad_ops = [
'%s_grad' % (op) for op in self._quantizable_ops
]
self._skip_pattern = skip_pattern
self._helper = LayerHelper(self.__class__.__name__)
self._moving_rate = 0.9
self._out_ch_axis1_ops = ['conv2d_transpose', 'mul', 'matmul']
def apply(self, program, startup_program, is_test=False):
"""
Apply quantization to fluid Program.
Args:
program(Program): the train or test program to be quantized.
startup_program(Program): the corresponding startup_program.
is_test(bool): Whether the program is used for testing.
Returns:
None
"""
assert isinstance(
program, Program
), "program must be the instance of Program"
assert isinstance(
startup_program, Program
), "startup_program must be the instance of Program"
var_rename_map = [
collections.OrderedDict() for _ in range(len(program.blocks))
]
with program_guard(program, startup_program):
for block in program.blocks:
ops = list(block.ops)
for op in ops:
if op.type in self._quantizable_ops and (
not self._is_skip_quant(op)
):
self._transform_forward(
block, op, var_rename_map, is_test
)
for block in program.blocks:
ops = list(block.ops)
for op in ops:
if op.type in self._quantizable_grad_ops and (
not self._is_skip_quant(op)
):
self._transform_backward(block, op, var_rename_map)
def convert(self, test_program, scope=None):
"""
Convert the test program.
Get the out scale from the moving_average_abs_max_scale op and save the
out scale into the quantized op.
Args:
test_program(Program): the test program to be converted.
scope(fluid.Scope, optional): the scope of the program, used to load
and save variables. If scope is None, global_scope() is used.
"""
scope = global_scope() if scope is None else scope
for block in test_program.blocks:
for op in block.ops:
if (
op.has_attr("quantization_type")
and op.attr("quantization_type") == "qat_with_weight"
):
# quant op -> var1 -> fake op -> var2
assert len(op.output_arg_names) == 1
var1_name = op.output_arg_names[0]
fake_ops = find_next_ops(block, var1_name)
assert len(fake_ops) == 1
fake_op = fake_ops[0]
assert fake_op.type == "moving_average_abs_max_scale"
out_scale_name = fake_op.output("OutScale")
out_threshold = load_variable_data(scope, out_scale_name[0])
op._set_attr("out_threshold", float(out_threshold))
var2_name = fake_op.output("Out")[0]
op._rename_output(var1_name, var2_name)
fake_op._rename_output(var2_name, var1_name)
def _transform_forward(self, block, op, var_rename_map, is_test):
"""
Insert fake quant op before the target ops.
"""
op._set_attr("quantization_type", "qat_with_weight")
# insert fake quant op before the quantized op
for in_name in op.input_arg_names:
block_id = block.idx
idx = block.ops.index(op)
if in_name in var_rename_map[block_id]:
new_in_name = var_rename_map[block_id][in_name]
else:
in_var = block.var(in_name)
target_dtype = [
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
]
if in_var.dtype not in target_dtype:
continue
quant_bits = (
self._weight_bits
if in_var.persistable
else self._activation_bits
)
quant_type = (
self._weight_quantize_type
if in_var.persistable
else self._activation_quantize_type
)
if quant_type == "abs_max":
new_var = self._insert_abs_max_fq_op(
block, idx, in_var, quant_bits
)
elif quant_type == "moving_average_abs_max":
new_var = self._insert_ma_abs_max_fq_op(
block, idx, in_var, quant_bits, is_test
)
elif quant_type == "channel_wise_abs_max":
ch_axis = 1 if op.type in self._out_ch_axis1_ops else 0
new_var = self._insert_pc_abs_max_fq_op(
block, idx, in_var, quant_bits, ch_axis
)
else:
_logger.error(
"Don't support the quant_type: %s" % quant_type
)
continue
new_in_name = new_var.name
var_rename_map[block_id][in_name] = new_in_name
op._rename_input(in_name, new_in_name)
# insert out scale op followed the quantized op
for out_name in op.output_arg_names:
next_ops = find_next_ops(block, out_name)
idx = block.ops.index(op)
out_var = block.var(out_name)
new_out_var = self._insert_ma_abs_max_scale_op(
block, idx + 1, out_var, is_test, True
)
for next_op in next_ops:
if "_grad" not in next_op.type:
next_op._rename_input(out_name, new_out_var.name)
def _is_skip_quant(self, op):
"""
Analyse whether the op should skip quantization or not.
"""
user_skipped = False
if isinstance(self._skip_pattern, list):
user_skipped = op.has_attr("op_namescope") and any(
pattern in op.attr("op_namescope")
for pattern in self._skip_pattern
)
elif isinstance(self._skip_pattern, str):
user_skipped = (
op.has_attr("op_namescope")
and op.attr("op_namescope").find(self._skip_pattern) != -1
)
return user_skipped
def _transform_backward(self, block, op, var_rename_map):
"""
Update the backward (grad) ops of the target ops.
Note: for the grad ops, only rename the inputs; do not rename the outputs.
"""
block_id = block.idx
no_dequanted_input_vars = True
for name in op.input_arg_names:
if name in var_rename_map[block_id]:
new_var_name = var_rename_map[block_id][name]
op._rename_input(name, new_var_name)
no_dequanted_input_vars = False
if no_dequanted_input_vars:
raise ValueError(
"There is no dequanted inputs for op %s." % (op.type)
)
def _insert_abs_max_fq_op(self, block, idx, in_var, quant_bits):
"""
Insert abs max fake quant op.
"""
quant_dequant_var = block.create_var(
type=in_var.type,
name="{}.quant_dequant".format(in_var.name),
shape=in_var.shape,
dtype=in_var.dtype,
)
scale_var = self._helper.create_parameter(
attr=ParamAttr(
name="{}.quant_dequant.scale".format(in_var.name),
initializer=Constant(0.0),
trainable=False,
),
shape=[1],
dtype=in_var.dtype,
)
scale_var.stop_gradient = True
inputs = {'X': in_var}
outputs = {'Out': quant_dequant_var, 'OutScale': scale_var}
attrs = {'bit_length': quant_bits}
block._insert_op(
idx,
type='fake_quantize_dequantize_abs_max',
attrs=attrs,
inputs=inputs,
outputs=outputs,
)
return quant_dequant_var
def _insert_ma_abs_max_fq_op(self, block, idx, in_var, quant_bits, is_test):
"""
Insert moving average abs max fake quant op.
"""
quant_dequant_var = block.create_var(
type=in_var.type,
name="{}.quant_dequant".format(in_var.name),
shape=in_var.shape,
dtype=in_var.dtype,
)
scale_var = self._helper.create_parameter(
attr=ParamAttr(
name="{}.quant_dequant.scale".format(in_var.name),
initializer=Constant(0.0),
trainable=False,
),
shape=[1],
dtype=in_var.dtype,
)
scale_var.stop_gradient = True
if not is_test:
state_var = self._helper.create_parameter(
attr=ParamAttr(
name="{}.quant_dequant.state".format(in_var.name),
initializer=Constant(0),
trainable=False,
),
shape=[1],
dtype=in_var.dtype,
)
state_var.stop_gradient = True
accum_var = self._helper.create_parameter(
attr=ParamAttr(
name="{}.quant_dequant.accum".format(in_var.name),
initializer=Constant(0),
trainable=False,
),
shape=[1],
dtype=in_var.dtype,
)
accum_var.stop_gradient = True
attrs = {
'moving_rate': self._moving_rate,
'bit_length': quant_bits,
'is_test': is_test,
}
inputs = {'X': in_var, 'InScale': scale_var}
outputs = {'Out': quant_dequant_var, 'OutScale': scale_var}
if not is_test:
inputs['InState'] = state_var
inputs['InAccum'] = accum_var
outputs['OutState'] = state_var
outputs['OutAccum'] = accum_var
block._insert_op(
idx,
type='fake_quantize_dequantize_moving_average_abs_max',
attrs=attrs,
inputs=inputs,
outputs=outputs,
)
return quant_dequant_var
def _insert_pc_abs_max_fq_op(self, block, idx, in_var, quant_bits, ch_axis):
"""
Insert per channel abs max fake quant op.
"""
quant_dequant_var = block.create_var(
type=in_var.type,
name="{}.quant_dequant".format(in_var.name),
shape=in_var.shape,
dtype=in_var.dtype,
)
scale_var = self._helper.create_parameter(
attr=ParamAttr(
name="{}.quant_dequant.scale".format(in_var.name),
initializer=Constant(0.0),
trainable=False,
),
shape=[in_var.shape[ch_axis]],
dtype=in_var.dtype,
)
scale_var.stop_gradient = True
inputs = {'X': in_var}
outputs = {'Out': quant_dequant_var, 'OutScale': scale_var}
attrs = {'bit_length': quant_bits, 'quant_axis': ch_axis}
block._insert_op(
idx,
type='fake_channel_wise_quantize_dequantize_abs_max',
attrs=attrs,
inputs=inputs,
outputs=outputs,
)
return quant_dequant_var
def _insert_ma_abs_max_scale_op(
self, block, idx, in_var, is_test, has_out_var=False
):
"""
Insert moving average abs max scale op.
"""
scale_var = self._helper.create_parameter(
attr=ParamAttr(
name="{}.outscale.scale".format(in_var.name),
initializer=Constant(0.0),
trainable=False,
),
shape=[1],
dtype=in_var.dtype,
)
scale_var.stop_gradient = True
attrs = {'moving_rate': self._moving_rate, 'is_test': is_test}
inputs = {'X': in_var}
outputs = {'OutScale': scale_var}
if not is_test:
state_var = self._helper.create_parameter(
attr=ParamAttr(
name="{}.outscale.state".format(in_var.name),
initializer=Constant(0),
trainable=False,
),
shape=[1],
dtype=in_var.dtype,
)
state_var.stop_gradient = True
accum_var = self._helper.create_parameter(
attr=ParamAttr(
name="{}.outscale.accum".format(in_var.name),
initializer=Constant(0),
trainable=False,
),
shape=[1],
dtype=in_var.dtype,
)
accum_var.stop_gradient = True
inputs['InState'] = state_var
inputs['InAccum'] = accum_var
outputs['OutState'] = state_var
outputs['OutAccum'] = accum_var
if has_out_var:
out_var = block.create_var(
type=in_var.type,
name="{}.tmp".format(in_var.name),
shape=in_var.shape,
dtype=in_var.dtype,
)
outputs['Out'] = out_var
block._insert_op(
idx,
type='moving_average_abs_max_scale',
attrs=attrs,
inputs=inputs,
outputs=outputs,
)
if has_out_var:
return out_var
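A minimal usage sketch of QuantizeTranspilerV2, following the test below (train_program, startup_program, test_program and scope are assumed to exist; ops built under a "skip_quant" name scope are excluded by the default skip_pattern):

qt = QuantizeTranspilerV2(
activation_quantize_type='moving_average_abs_max',
weight_quantize_type='channel_wise_abs_max',
)
qt.apply(train_program, startup_program, is_test=False)
qt.apply(test_program, startup_program, is_test=True)
# ... run training, then write the collected out scales into the test program
qt.convert(test_program, scope)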
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import os
import unittest
import random
import numpy as np
import paddle.fluid as fluid
import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization.quantize_transpiler_v2 import (
QuantizeTranspilerV2,
)
from paddle.fluid import core
paddle.enable_static()
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CPU_NUM"] = "1"
def conv_net(img, label):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
pool_type='max',
act="relu",
)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
pool_type='avg',
act="relu",
)
with fluid.name_scope("skip_quant"):
hidden = fluid.layers.fc(input=conv_pool_1, size=100, act='relu')
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
loss = paddle.nn.functional.cross_entropy(
input=prediction, label=label, reduction='none', use_softmax=False
)
avg_loss = paddle.mean(loss)
return avg_loss
class TestQuantizeProgramPass(unittest.TestCase):
def quantize_program(
self,
use_cuda,
seed,
activation_quant_type='abs_max',
weight_quant_type='abs_max',
for_ci=False,
):
def build_program(main, startup, is_test):
main.random_seed = seed
startup.random_seed = seed
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
img = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32'
)
label = fluid.layers.data(
name='label', shape=[1], dtype='int64'
)
loss = conv_net(img, label)
if not is_test:
opt = fluid.optimizer.Adam(learning_rate=0.0001)
opt.minimize(loss)
return [img, label], loss
random.seed(0)
np.random.seed(0)
# 1 Define program
train_program = fluid.Program()
startup_program = fluid.Program()
test_program = fluid.Program()
feeds, loss = build_program(train_program, startup_program, False)
build_program(test_program, startup_program, True)
test_program = test_program.clone(for_test=True)
if not for_ci:
train_graph = IrGraph(
core.Graph(train_program.desc), for_test=False
)
train_graph.draw('.', 'train_program_1')
test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
test_graph.draw('.', 'test_program_1')
# 2 Apply quantization
qt = QuantizeTranspilerV2(
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quant_type,
)
qt.apply(train_program, startup_program, is_test=False)
qt.apply(test_program, startup_program, is_test=True)
# 3 Train
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
with fluid.scope_guard(scope):
exe.run(startup_program)
if not for_ci:
train_graph = IrGraph(
core.Graph(train_program.desc), for_test=False
)
train_graph.draw('.', 'train_program_2')
test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
test_graph.draw('.', 'test_program_2')
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
build_strategy.fuse_all_reduce_ops = False
binary = fluid.CompiledProgram(train_program).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy
)
iters = 5
batch_size = 8
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size
)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
with fluid.scope_guard(scope):
for idx in range(iters):
data = next(train_reader())
loss_v = exe.run(
binary, feed=feeder.feed(data), fetch_list=[loss]
)
if not for_ci and idx % 20 == 0:
print('{}: {}'.format('loss', np.mean(loss_v)))
print('{}: {}'.format('loss', np.mean(loss_v)))
# 4 Convert
qt.convert(test_program, scope)
if not for_ci:
with fluid.scope_guard(scope):
fluid.io.save_inference_model(
'./infer_model',
['image', 'label'],
[loss],
exe,
test_program,
clip_extra=True,
)
def test_gpu_1(self):
if fluid.core.is_compiled_with_cuda():
self.quantize_program(
use_cuda=True,
seed=1,
activation_quant_type='abs_max',
weight_quant_type='abs_max',
for_ci=True,
)
def test_gpu_2(self):
if fluid.core.is_compiled_with_cuda():
self.quantize_program(
use_cuda=True,
seed=1,
activation_quant_type='moving_average_abs_max',
weight_quant_type='channel_wise_abs_max',
for_ci=True,
)
def test_cpu_1(self):
self.quantize_program(
use_cuda=False,
seed=2,
activation_quant_type='abs_max',
weight_quant_type='abs_max',
for_ci=True,
)
def test_cpu_2(self):
self.quantize_program(
use_cuda=False,
seed=2,
activation_quant_type='moving_average_abs_max',
weight_quant_type='channel_wise_abs_max',
for_ci=True,
)
if __name__ == '__main__':
unittest.main()
......@@ -25,5 +25,4 @@ set_tests_properties(test_multi_precision_fp16_train PROPERTIES TIMEOUT 120)
if(APPLE)
set_tests_properties(test_model_cast_to_bf16 PROPERTIES TIMEOUT 300)
set_tests_properties(test_quantize_transpiler PROPERTIES TIMEOUT 300)
endif()
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import numpy as np
import unittest
import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.quantize.quantize_transpiler import _original_var_name
from paddle.fluid.contrib.quantize.quantize_transpiler import QuantizeTranspiler
paddle.enable_static()
def linear_fc(num):
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = data
for _ in range(num):
hidden = fluid.layers.fc(hidden, size=128, act='relu')
loss = paddle.nn.functional.cross_entropy(
input=hidden, label=label, reduction='none', use_softmax=False
)
loss = paddle.mean(loss)
return loss
def residual_block(num):
def conv_bn_layer(
input, ch_out, filter_size, stride, padding, act='relu', bias_attr=False
):
tmp = paddle.static.nn.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
act=None,
bias_attr=bias_attr,
)
return paddle.static.nn.batch_norm(input=tmp, act=act)
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = data
for _ in range(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = paddle.nn.functional.relu(paddle.add(x=conv, y=short))
fc = fluid.layers.fc(input=hidden, size=10)
loss = paddle.nn.functional.cross_entropy(
input=fc, label=label, reduction='none', use_softmax=False
)
loss = paddle.mean(loss)
return loss
def conv_net(img, label):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
act="relu",
)
conv_pool_1 = paddle.static.nn.batch_norm(conv_pool_1)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
act="relu",
)
prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
loss = paddle.nn.functional.cross_entropy(
input=prediction, label=label, reduction='none', use_softmax=False
)
avg_loss = paddle.mean(loss)
return avg_loss
class TestQuantizeTranspiler(unittest.TestCase):
def setUp(self):
# since quant_op and dequant_op is not ready, use cos and sin for test
self.weight_quant_op_type = 'fake_quantize_abs_max'
self.dequant_op_type = 'fake_dequantize_max_abs'
self.quantizable_op_and_inputs = {
'conv2d': ['Input', 'Filter'],
'depthwise_conv2d': ['Input', 'Filter'],
'mul': ['X', 'Y'],
}
self.quantizable_op_grad_and_inputs = {
'conv2d_grad': ['Input', 'Filter'],
'depthwise_conv2d_grad': ['Input', 'Filter'],
'mul_grad': ['X', 'Y'],
}
def check_program(self, program):
quantized_ops = {}
persistable_vars = [
v.name
for v in filter(lambda var: var.persistable, program.list_vars())
]
for block in program.blocks:
for idx, op in enumerate(block.ops):
# check forward
if op.type in self.quantizable_op_and_inputs:
for i, arg_name in enumerate(op.input_arg_names):
quant_op_type = (
self.weight_quant_op_type
if _original_var_name(arg_name) in persistable_vars
else self.act_quant_op_type
)
self.assertTrue(
arg_name.endswith('.quantized.dequantized')
)
if arg_name not in quantized_ops:
self.assertEqual(
block.ops[idx - 2 * i - 1].type,
self.dequant_op_type,
)
self.assertEqual(
block.ops[idx - 2 * i - 2].type, quant_op_type
)
quantized_ops[arg_name] = block.ops[idx - 2 * i - 2]
else:
op_idx = block.ops.index(quantized_ops[arg_name])
self.assertLess(op_idx, idx)
# check backward
if op.type in self.quantizable_op_grad_and_inputs:
for pname in self.quantizable_op_grad_and_inputs[op.type]:
arg_name = op.input(pname)[0]
self.assertTrue(
arg_name.endswith('.quantized.dequantized')
)
self.assertTrue(arg_name in quantized_ops)
def linear_fc_quant(self, quant_type):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = linear_fc(3)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
t = QuantizeTranspiler(activation_quantize_type=quant_type)
t.training_transpile(main)
self.check_program(main)
def test_linear_fc_quant_abs_max(self):
self.act_quant_op_type = 'fake_quantize_abs_max'
self.linear_fc_quant('abs_max')
def test_linear_fc_quant_range_abs_max(self):
self.act_quant_op_type = 'fake_quantize_range_abs_max'
self.linear_fc_quant('range_abs_max')
def residual_block_quant(self, quant_type):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = residual_block(2)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
t = QuantizeTranspiler(activation_quantize_type=quant_type)
t.training_transpile(main)
self.check_program(main)
def test_residual_block_abs_max(self):
self.act_quant_op_type = 'fake_quantize_abs_max'
self.residual_block_quant('abs_max')
def test_residual_block_range_abs_max(self):
self.act_quant_op_type = 'fake_quantize_range_abs_max'
self.residual_block_quant('range_abs_max')
def freeze_program(self, use_cuda, seed):
def build_program(main, startup, is_test):
main.random_seed = seed
startup.random_seed = seed
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
img = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32'
)
label = fluid.layers.data(
name='label', shape=[1], dtype='int64'
)
loss = conv_net(img, label)
if not is_test:
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
return [img, label], loss
main = fluid.Program()
startup = fluid.Program()
test_program = fluid.Program()
import random
random.seed(0)
np.random.seed(0)
feeds, loss = build_program(main, startup, False)
build_program(test_program, startup, True)
test_program = test_program.clone(for_test=True)
quant_type = 'range_abs_max' # 'range_abs_max' or 'abs_max'
quant_transpiler = QuantizeTranspiler(
activation_quantize_type=quant_type
)
quant_transpiler.training_transpile(main, startup)
quant_transpiler.training_transpile(test_program, startup)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
iters = 5
batch_size = 8
class_num = 10
exe.run(startup)
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
batch_size=batch_size,
)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size
)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
with fluid.program_guard(main):
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(
program=main, feed=feeder.feed(data), fetch_list=[loss]
)
with fluid.program_guard(test_program):
test_data = next(test_reader())
w_var = fluid.framework._get_var(
'conv2d_1.w_0.quantized', test_program
)
# Testing during training
test_loss1, w_quant = exe.run(
program=test_program,
feed=feeder.feed(test_data),
fetch_list=[loss, w_var],
)
# Freeze program for inference, but the weight of fc/conv is still float type.
quant_transpiler.freeze_program(test_program, place)
(test_loss2,) = exe.run(
program=test_program,
feed=feeder.feed(test_data),
fetch_list=[loss],
)
self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
w_freeze = np.array(
fluid.global_scope().find_var('conv2d_1.w_0').get_tensor()
)
# fail: -432.0 != -433.0, this is due to the calculation precision
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
# Convert parameter to 8-bit.
quant_transpiler.convert_to_int8(test_program, place)
# Save the 8-bit parameter and model file.
fluid.io.save_inference_model(
'model_8bit',
['image', 'label'],
[loss],
exe,
test_program,
clip_extra=True,
)
# Test whether the 8-bit parameter and model file can be loaded successfully.
[infer, feed, fetch] = fluid.io.load_inference_model(
'model_8bit', exe
)
# Check the loaded 8-bit weight.
w_8bit = np.array(
fluid.global_scope().find_var('conv2d_1.w_0.int8').get_tensor()
)
self.assertEqual(w_8bit.dtype, np.int8)
self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
def not_test_freeze_program_cuda(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
self.freeze_program(True, seed=1)
def not_test_freeze_program_cpu(self):
with fluid.unique_name.guard():
self.freeze_program(False, seed=2)
if __name__ == '__main__':
unittest.main()
......@@ -23,7 +23,7 @@ import paddle.distributed.fleet as fleet
import paddle.fluid as fluid
import paddle.nn as nn
from paddle.distributed.utils.launch_utils import find_free_ports, get_cluster
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.quantization import ImperativeQuantAware
def set_random_seed(seed, dp_id, rank_id):
......
......@@ -20,10 +20,6 @@ import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.contrib.slim.quantization import (
QuantizationFreezePass,
QuantizationTransformPass,
)
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import (
IrGraph,
......@@ -32,6 +28,10 @@ from paddle.fluid.framework import (
convert_np_dtype_to_dtype_,
)
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.static.quantization import (
QuantizationFreezePass,
QuantizationTransformPass,
)
class TensorConfig:
......
......@@ -21,16 +21,16 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, Variable, core
from paddle.fluid.contrib.slim.quantization import (
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor
from paddle.fluid.framework import IrGraph
from paddle.fluid.io import append_fetch_ops, prepend_feed_ops
from paddle.static.quantization import (
AddQuantDequantPass,
OutScaleForInferencePass,
OutScaleForTrainingPass,
QuantizationFreezePass,
QuantizationTransformPass,
)
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor
from paddle.fluid.framework import IrGraph
from paddle.fluid.io import append_fetch_ops, prepend_feed_ops
class QuantDequantTest(unittest.TestCase):
......
......@@ -18,9 +18,9 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.framework import IrGraph, Program, program_guard
from paddle.fluid.tests.unittests.op_test import OpTestTool
from paddle.static.quantization import QuantizationTransformPass
paddle.enable_static()
......
......@@ -24,7 +24,7 @@ from PIL import Image
import paddle
import paddle.fluid as fluid
from paddle.dataset.common import download
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.static.quantization import PostTrainingQuantization
paddle.enable_static()
......
......@@ -12,40 +12,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..fluid.contrib.slim.quantization.imperative.ptq_config import (
from .imperative.ptq_config import (
PTQConfig,
default_ptq_config,
)
from ..fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
from .imperative.ptq_quantizer import (
BaseQuantizer,
)
from ..fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
from .imperative.ptq_quantizer import (
AbsmaxQuantizer,
)
from ..fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
from .imperative.ptq_quantizer import (
PerChannelAbsmaxQuantizer,
)
from ..fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
from .imperative.ptq_quantizer import (
KLQuantizer,
)
from ..fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
from .imperative.ptq_quantizer import (
HistQuantizer,
)
from ..fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
from .imperative.ptq_quantizer import (
SUPPORT_ACT_QUANTIZERS,
)
from ..fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
from .imperative.ptq_quantizer import (
SUPPORT_WT_QUANTIZERS,
)
from ..fluid.contrib.slim.quantization.imperative.ptq_registry import (
from .imperative.ptq_registry import (
PTQRegistry,
)
from ..fluid.contrib.slim.quantization.imperative.ptq import ImperativePTQ
from ..fluid.contrib.slim.quantization.imperative.qat import (
from .imperative.ptq import (
ImperativePTQ,
)
from .imperative.qat import (
ImperativeQuantAware,
)
from .config import QuantConfig
from .base_quanter import BaseQuanter
from .factory import quanter
......
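After this move, the imperative PTQ and QAT entry points are exposed directly from paddle.quantization; for example (names taken from the imports above):

from paddle.quantization import (
ImperativePTQ,
ImperativeQuantAware,
PTQConfig,
default_ptq_config,
)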
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,23 +13,24 @@
# limitations under the License.
from . import qat
from .qat import *
from .qat import ImperativeQuantAware
from . import ptq
from .ptq import *
from .ptq import ImperativePTQ
from . import ptq_config
from .ptq_config import *
from .ptq_config import PTQConfig, default_ptq_config
from . import ptq_quantizer
from .ptq_quantizer import *
from .ptq_quantizer import (
BaseQuantizer,
AbsmaxQuantizer,
PerChannelAbsmaxQuantizer,
KLQuantizer,
HistQuantizer,
SUPPORT_ACT_QUANTIZERS,
SUPPORT_WT_QUANTIZERS,
)
from . import ptq_registry
from .ptq_registry import *
__all__ = []
__all__ += qat.__all__
__all__ += ptq.__all__
__all__ += ptq_config.__all__
__all__ += ptq_quantizer.__all__
__all__ += ptq_registry.__all__
from .ptq_registry import PTQRegistry
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,8 +13,10 @@
# limitations under the License.
import copy
import paddle
import paddle.nn as nn
from . import utils
......@@ -66,7 +68,7 @@ def fuse_layers(model, layers_to_fuse, inplace=False):
Return
fused_model(paddle.nn.Layer): The fused model.
'''
if inplace == False:
if inplace is False:
model = copy.deepcopy(model)
for layers in layers_to_fuse:
_fuse_layers(model, layers)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,24 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import copy
import logging
import os
import numpy as np
import paddle
import paddle.nn.quant.quant_layers as quant_layers
from paddle.fluid.log_helper import get_logger
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from . import fuse_utils
from . import utils
from . import ptq_hooks
from . import ptq_config
from . import ptq_quantizer
from ...static.log_helper import get_logger
from ...static.quantization.utils import (
_get_input_name_index,
_get_op_input_var_names,
_get_op_output_var_names,
_get_output_name_index,
)
from . import fuse_utils, ptq_config, ptq_hooks, ptq_quantizer, utils
from .ptq_registry import PTQRegistry
__all__ = ['ImperativePTQ']
INFER_MODEL_SUFFIX = ".pdmodel"
INFER_PARAMS_SUFFIX = ".pdiparams"
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
......@@ -165,8 +168,8 @@ class ImperativePTQ:
infer_program,
feed_target_names,
fetch_targets,
] = paddle.fluid.io.load_inference_model(
dirname=dirname,
] = paddle.static.load_inference_model(
path_prefix=dirname,
executor=exe,
model_filename=model_filename,
params_filename=params_filename,
......@@ -178,14 +181,23 @@ class ImperativePTQ:
self._remove_scale_op(infer_program)
# Save final program
paddle.fluid.io.save_inference_model(
dirname=dirname,
feeded_var_names=feed_target_names,
target_vars=fetch_targets,
model_name = None
if model_filename is None:
model_name = "model"
elif model_filename.endswith(".pdmodel"):
model_name = model_filename.rsplit(".", 1)[0]
else:
model_name = model_filename
path_prefix = os.path.join(dirname, model_name)
feed_vars = [
infer_program.global_block().var(name) for name in feed_target_names
]
paddle.static.save_inference_model(
path_prefix,
feed_vars,
fetch_targets,
executor=exe,
main_program=infer_program.clone(),
model_filename=model_filename,
params_filename=params_filename,
program=infer_program.clone(),
)
if is_dynamic_mode:
......@@ -302,7 +314,7 @@ class ImperativePTQ:
) and PTQRegistry.is_simulated_quant_layer(sub_layer):
quant_config = sub_layer._quant_config
assert quant_config.enable_in_act_quantizer == True
assert quant_config.enable_in_act_quantizer is True
wt_quantizer = quant_config.wt_quantizer
in_act_quantizer = quant_config.in_act_quantizer
......@@ -376,7 +388,7 @@ class ImperativePTQ:
None
"""
for op in utils.program_all_ops(program):
for in_var_name in utils._get_op_input_var_names(op):
for in_var_name in _get_op_input_var_names(op):
previous_op = utils.find_previous_op(op.block, in_var_name)
if previous_op is None:
continue
......@@ -388,20 +400,16 @@ class ImperativePTQ:
attr_name = previous_op.output('OutScale')[0]
in_threshold = utils.load_variable_data(scope, attr_name)
in_threshold = utils.fp_numpy_to_naive(in_threshold)
argname, index = utils._get_input_name_index(
op, in_var_name
)
argname, index = _get_input_name_index(op, in_var_name)
op._set_attr(
argname + str(index) + "_threshold", in_threshold
)
op._set_attr("with_quant_attr", True)
else:
for out_var_name in utils._get_op_output_var_names(
previous_op
):
for out_var_name in _get_op_output_var_names(previous_op):
if out_var_name != in_var_name:
continue
argname, index = utils._get_output_name_index(
argname, index = _get_output_name_index(
previous_op, out_var_name
)
attr_name = argname + str(index) + "_threshold"
......@@ -409,9 +417,7 @@ class ImperativePTQ:
continue
threshold = previous_op.attr(attr_name)
argname, index = utils._get_input_name_index(
op, in_var_name
)
argname, index = _get_input_name_index(op, in_var_name)
attr_name = argname + str(index) + "_threshold"
op._set_attr(attr_name, threshold)
op._set_attr("with_quant_attr", True)
......@@ -453,10 +459,10 @@ class ImperativePTQ:
continue
next_op = next_ops[0]
argname, index = utils._get_output_name_index(op, out_var_name)
argname, index = _get_output_name_index(op, out_var_name)
old_attr_name = argname + str(index) + "_threshold"
argname, index = utils._get_output_name_index(
argname, index = _get_output_name_index(
next_op, next_op.output("Out")[0]
)
new_attr_name = argname + str(index) + "_threshold"
......@@ -478,7 +484,7 @@ class ImperativePTQ:
@staticmethod
def _is_skip_layer(layer):
return hasattr(layer, "skip_quant") and layer.skip_quant == True
return hasattr(layer, "skip_quant") and layer.skip_quant is True
@staticmethod
def _is_quant_layer(layer):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import copy
import paddle
from .ptq_quantizer import *
__all__ = ['PTQConfig', 'default_ptq_config']
from .ptq_quantizer import (
SUPPORT_ACT_QUANTIZERS,
SUPPORT_WT_QUANTIZERS,
KLQuantizer,
PerChannelAbsmaxQuantizer,
)
class PTQConfig:
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,12 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import math
import numpy as np
from . import ptq_config
from .ptq_registry import PTQRegistry
def quant_forward_post_hook(layer, inputs, outputs):
"""
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,24 +13,14 @@
# limitations under the License.
import abc
import copy
import math
import numpy as np
import paddle
from ...static.quantization.cal_kl_threshold import cal_kl_threshold
from . import utils
from ..cal_kl_threshold import cal_kl_threshold
__all__ = [
'BaseQuantizer',
'AbsmaxQuantizer',
'PerChannelAbsmaxQuantizer',
'KLQuantizer',
'HistQuantizer',
'SUPPORT_ACT_QUANTIZERS',
'SUPPORT_WT_QUANTIZERS',
]
def abs_max_value(tensor):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,8 +14,6 @@
import paddle
__all__ = ['PTQRegistry']
class LayerInfo:
"""
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,35 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
import numpy as np
import sys
import os
import warnings
import paddle
import paddle.nn as nn
import paddle.nn.quant.quant_layers as quant_layers
from paddle.fluid import dygraph, core, framework, unique_name
from paddle.fluid.framework import IrGraph
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.fluid.io import load_inference_model, save_inference_model
from ..quantization_pass import ReplaceFakeQuantDequantPass, QuantWeightPass
from paddle.fluid.log_helper import get_logger
from .. import quantization_pass
from ..utils import move_persistable_var_to_global_block
from . import utils
from . import fuse_utils
__all__ = ['ImperativeQuantAware']
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
from paddle.framework import core
from ...static.quantization.quantization_pass import (
QuantWeightPass,
ReplaceFakeQuantDequantPass,
)
from ...static.quantization.utils import (
_get_input_name_index,
_get_op_input_var_names,
_get_output_name_index,
move_persistable_var_to_global_block,
)
from . import fuse_utils, utils
INFER_MODEL_SUFFIX = ".pdmodel"
INFER_PARAMS_SUFFIX = ".pdiparams"
def lazy_import_fleet(layer_name_map, fake_quant_input_layers):
......@@ -147,7 +139,7 @@ class ImperativeQuantAware:
.. code-block:: python
import paddle
from paddle.fluid.contrib.slim.quantization \
from paddle.static.quantization \
import ImperativeQuantAware
from paddle.vision.models \
import resnet
......@@ -178,7 +170,7 @@ class ImperativeQuantAware:
.. code-block:: python
import paddle
from paddle.fluid.contrib.slim.quantization \
from paddle.static.quantization \
import ImperativeQuantAware
class ImperativeModel(paddle.nn.Layer):
......@@ -256,7 +248,7 @@ class ImperativeQuantAware:
.. code-block:: python
import paddle
from paddle.fluid.contrib.slim.quantization \
from paddle.static.quantization \
import ImperativeQuantAware
class ImperativeModel(paddle.nn.Layer):
......@@ -288,8 +280,8 @@ class ImperativeQuantAware:
imperative_qat.quantize(model)
"""
assert isinstance(
model, dygraph.Layer
), "The model must be the instance of dygraph.Layer."
model, paddle.nn.Layer
), "The model must be the instance of paddle.nn.Layer."
if self.fuse_conv_bn:
fuse_utils.fuse_conv_bn(model)
......@@ -376,7 +368,7 @@ class ImperativeQuantizeInputs:
), "activation_bits should be 1, 2,... or 16."
layer_check = lambda method: method is None or issubclass(
method, dygraph.layers.Layer
method, paddle.nn.Layer
)
assert layer_check(
weight_preprocess_layer
......@@ -417,13 +409,13 @@ class ImperativeQuantizeInputs:
"""
assert isinstance(
model, dygraph.Layer
), "The model must be the instance of dygraph.Layer."
model, paddle.nn.Layer
), "The model must be the instance of paddle.nn.Layer."
for name, cur_layer in model.named_sublayers():
if not isinstance(cur_layer, self._quantizable_layer_type) or (
hasattr(cur_layer, "skip_quant")
and cur_layer.skip_quant == True
and cur_layer.skip_quant is True
):
continue
......@@ -480,8 +472,8 @@ class ImperativeQuantizeOutputs:
None
"""
assert isinstance(
model, dygraph.Layer
), "The model must be the instance of dygraph.Layer."
model, paddle.nn.Layer
), "The model must be the instance of paddle.nn.Layer."
for cur_name, cur_layer in model.named_sublayers():
if '_act_preprocess' in cur_name:
......@@ -535,8 +527,8 @@ class ImperativeQuantizeOutputs:
None
"""
assert isinstance(
model, dygraph.Layer
), "The model must be the instance of dygraph.Layer."
model, paddle.nn.Layer
), "The model must be the instance of paddle.nn.Layer."
paddle.jit.save(layer=model, path=path, input_spec=input_spec, **config)
......@@ -546,8 +538,8 @@ class ImperativeQuantizeOutputs:
paddle.enable_static()
place = core.CPUPlace()
scope = global_scope()
exe = Executor(place)
scope = paddle.static.global_scope()
exe = paddle.static.Executor(place)
dirname = os.path.dirname(path)
basename = os.path.basename(path)
......@@ -558,8 +550,8 @@ class ImperativeQuantizeOutputs:
infer_program,
feed_target_names,
fetch_targets,
] = load_inference_model(
dirname=dirname,
] = paddle.static.load_inference_model(
dirname,
executor=exe,
model_filename=model_filename,
params_filename=params_filename,
......@@ -600,14 +592,23 @@ class ImperativeQuantizeOutputs:
move_persistable_var_to_global_block(infer_program)
save_inference_model(
dirname=dirname,
feeded_var_names=feed_target_names,
target_vars=fetch_targets,
model_name = None
if model_filename is None:
model_name = "model"
elif model_filename.endswith(".pdmodel"):
model_name = model_filename.rsplit(".", 1)[0]
else:
model_name = model_filename
path_prefix = os.path.join(dirname, model_name)
feed_vars = [
infer_program.global_block().var(name) for name in feed_target_names
]
paddle.static.save_inference_model(
path_prefix,
feed_vars,
fetch_targets,
executor=exe,
main_program=infer_program.clone(),
model_filename=model_filename,
params_filename=params_filename,
program=infer_program.clone(),
clip_extra=clip_extra,
)
......@@ -619,7 +620,7 @@ class ImperativeQuantizeOutputs:
Whether the layer needs to calculate output scales.
"""
# exclude fake_quant ops in quant_layers file
if not isinstance(layer, dygraph.Layer):
if not isinstance(layer, paddle.nn.Layer):
return False
if self._onnx_format:
......@@ -660,7 +661,7 @@ class ImperativeQuantizeOutputs:
target_ops.append(op)
for op in target_ops:
for in_var_name in utils._get_op_input_var_names(op):
for in_var_name in _get_op_input_var_names(op):
previous_op = utils.find_previous_op(op.block, in_var_name)
if previous_op is not None and (
......@@ -670,9 +671,7 @@ class ImperativeQuantizeOutputs:
scale_name = previous_op.output('OutScale')[0]
in_scale = utils.load_variable_data(scope, scale_name)
in_scale = utils.fp_numpy_to_naive(in_scale)
argname, index = utils._get_input_name_index(
op, in_var_name
)
argname, index = _get_input_name_index(op, in_var_name)
op._set_attr(
argname + str(index) + "_threshold", in_scale
)
......@@ -697,7 +696,7 @@ class ImperativeQuantizeOutputs:
out_scale = utils.fp_numpy_to_naive(out_scale)
if previous_op.type != "feed":
res = utils._get_output_name_index(previous_op, in_var_name)
res = _get_output_name_index(previous_op, in_var_name)
if res is not None:
argname, index = res
previous_op._set_attr(
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,19 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle
import paddle.nn.quant.quant_layers as quant_layers
from ..utils import (
_get_op_input_var_names,
_get_op_output_var_names,
_get_output_name_index,
_get_input_name_index,
)
layer_name_map = {
'Conv2DTranspose': paddle.nn.Conv2DTranspose,
'Conv2D': paddle.nn.Conv2D,
......@@ -42,7 +34,6 @@ layer_name_map = {
'Softmax': paddle.nn.Softmax,
'Swish': paddle.nn.Swish,
'Tanh': paddle.nn.Tanh,
'Hardswish': paddle.nn.Hardswish,
'BatchNorm': paddle.nn.BatchNorm,
'GroupNorm': paddle.nn.GroupNorm,
'LayerNorm': paddle.nn.LayerNorm,
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import quantize_transpiler
from .quantize_transpiler import *
import logging
__all__ = quantize_transpiler.__all__
def get_logger(name, level, fmt=None):
"""
Get a logger from the logging module with the given name, level and
format, without calling logging.basicConfig, because calling basicConfig
inside paddle would prevent basicConfig from taking effect after
importing paddle.
Args:
name (str): The logger name.
level (logging.LEVEL): The base level of the logger
fmt (str): Format of logger output
Returns:
logging.Logger: logging logger with given settings
Examples:
.. code-block:: python
import paddle
import logging
logger = paddle.static.log_helper.get_logger(__name__, logging.INFO,
fmt='%(asctime)s-%(levelname)s: %(message)s')
"""
logger = logging.getLogger(name)
logger.setLevel(level)
handler = logging.StreamHandler()
if fmt:
formatter = logging.Formatter(fmt=fmt, datefmt='%a %b %d %H:%M:%S')
handler.setFormatter(formatter)
logger.addHandler(handler)
# Stop propagation, otherwise the log may be printed multiple times.
logger.propagate = False
return logger
......@@ -12,50 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ...fluid.contrib.slim.quantization.quantization_pass import (
from .quantization_pass import (
QuantizationTransformPass,
)
from ...fluid.contrib.slim.quantization.quantization_pass import (
from .quantization_pass import (
QuantizationFreezePass,
)
from ...fluid.contrib.slim.quantization.quantization_pass import (
from .quantization_pass import (
ConvertToInt8Pass,
)
from ...fluid.contrib.slim.quantization.quantization_pass import (
from .quantization_pass import (
TransformForMobilePass,
)
from ...fluid.contrib.slim.quantization.quantization_pass import (
from .quantization_pass import (
OutScaleForTrainingPass,
)
from ...fluid.contrib.slim.quantization.quantization_pass import (
from .quantization_pass import (
OutScaleForInferencePass,
)
from ...fluid.contrib.slim.quantization.quantization_pass import (
from .quantization_pass import (
AddQuantDequantPass,
)
from ...fluid.contrib.slim.quantization.quantization_pass import (
from .quantization_pass import (
ReplaceFakeQuantDequantPass,
)
from ...fluid.contrib.slim.quantization.quantization_pass import QuantWeightPass
from ...fluid.contrib.slim.quantization.quantization_pass import (
from .quantization_pass import (
QuantWeightPass,
)
from .quantization_pass import (
QuantizationTransformPassV2,
)
from ...fluid.contrib.slim.quantization.quantization_pass import (
from .quantization_pass import (
AddQuantDequantPassV2,
)
from ...fluid.contrib.slim.quantization.quant_int8_mkldnn_pass import (
from .quantization_pass import (
AddQuantDequantForInferencePass,
)
from .quant_int8_mkldnn_pass import (
QuantInt8MkldnnPass,
)
from ...fluid.contrib.slim.quantization.quant2_int8_mkldnn_pass import (
from .quant2_int8_mkldnn_pass import (
Quant2Int8MkldnnPass,
)
from ...fluid.contrib.slim.quantization.post_training_quantization import (
from .post_training_quantization import (
PostTrainingQuantization,
)
from ...fluid.contrib.slim.quantization.post_training_quantization import (
from .post_training_quantization import (
PostTrainingQuantizationProgram,
)
from ...fluid.contrib.slim.quantization.post_training_quantization import (
from .post_training_quantization import (
WeightQuantization,
)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,25 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import time
import sys
import logging
import paddle
import sys
import time
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.static as static
from ....log_helper import get_logger
from ..log_helper import get_logger
from .utils import (
_channelwise_quant_axis1_ops,
bias_correction_w,
calculate_quant_cos_error,
dequant_tensor,
load_variable_data,
quant_tensor,
set_variable_data,
stable_sigmoid,
quant_tensor,
dequant_tensor,
_channelwise_quant_axis1_ops,
calculate_quant_cos_error,
bias_correction_w,
)
_logger = get_logger(
......@@ -42,7 +42,7 @@ ZETA = 1.1
def compute_soft_rounding(alpha_v):
return fluid.layers.clip(
return paddle.clip(
paddle.nn.functional.sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA,
min=0,
max=1,
......@@ -83,11 +83,9 @@ class AdaRoundLoss:
return round_loss
round_loss = paddle.static.nn.cond(
round_loss = static.nn.cond(
warm_start,
lambda: fluid.layers.fill_constant(
shape=[1], dtype='float32', value=0.0
),
lambda: paddle.full(shape=[1], dtype='float32', fill_value=0.0),
round_loss_fn,
)
......@@ -151,7 +149,7 @@ class AdaRound:
shape=alpha.shape,
dtype="float32",
name=var_name + ".alpha",
default_initializer=fluid.initializer.NumpyArrayInitializer(alpha),
default_initializer=paddle.nn.initializer.Assign(alpha),
)
def _calculate_output_with_adarounded_weights(
......@@ -258,12 +256,12 @@ def run_adaround(
fetch_op_name = quant_op_out_name
# build adaround program
exec_strategy = fluid.ExecutionStrategy()
exec_strategy = static.ExecutionStrategy()
exec_strategy.num_iteration_per_drop_scope = 1
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
with fluid.unique_name.guard():
startup_program = static.Program()
train_program = static.Program()
with static.program_guard(train_program, startup_program):
with paddle.utils.unique_name.guard():
# initialize adaround
adaround = AdaRound(
scale,
......@@ -273,21 +271,21 @@ def run_adaround(
weight_op_type=weight_op_type,
num_iterations=num_iterations,
)
orig_out_tensor = fluid.data(
orig_out_tensor = static.data(
name='orig_out_tensor',
shape=fp32_fetch_list.shape,
shape=(-1,) + fp32_fetch_list.shape,
dtype='float32',
)
adaround_out_tensor = fluid.data(
adaround_out_tensor = static.data(
name='adaround_out_tensor',
shape=fp32_fetch_list.shape,
shape=(-1,) + fp32_fetch_list.shape,
dtype='float32',
)
beta_tensor = fluid.data(
name='beta', shape=[1], dtype='float32'
beta_tensor = static.data(
name='beta', shape=[-1, 1], dtype='float32'
)
warm_start_tensor = fluid.data(
name='warm_start', shape=[1], dtype='bool'
warm_start_tensor = static.data(
name='warm_start', shape=[-1, 1], dtype='bool'
)
train_fetches_loss = adaround.get_loss(
......@@ -296,7 +294,7 @@ def run_adaround(
adaround_out_tensor,
orig_out_tensor,
)
optimizer = fluid.optimizer.Adam(learning_rate=lr)
optimizer = paddle.optimizer.Adam(learning_rate=lr)
loss = train_fetches_loss['loss']
optimizer.minimize(loss)
exe.run(startup_program)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,15 +14,15 @@
import logging
import math
import numpy as np
from ....log_helper import get_logger
from ..log_helper import get_logger
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
__all__ = ['cal_kl_threshold']
def expand_quantized_bins(quantized_bins, reference_bins):
'''
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,43 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import re
import math
import shutil
import logging
import numpy as np
try:
from tqdm import tqdm
except:
from .utils import tqdm
from inspect import isgeneratorfunction
from .... import io
from .... import core
from .... import reader
from .... import framework
from .... import unique_name
from ....executor import global_scope, Executor
from ....framework import IrGraph
from ....log_helper import get_logger
from paddle.fluid.framework import IrGraph, _get_var
from ... import io, static
from ...fluid import reader
from ...framework import core
from ...utils import unique_name
from ..log_helper import get_logger
from . import utils
from .adaround import run_adaround
from .cal_kl_threshold import cal_kl_threshold
from .quantization_pass import (
AddQuantDequantPass,
AddQuantDequantPassV2,
QuantizationFreezePass,
QuantizationTransformPass,
QuantizationTransformPassV2,
QuantizationFreezePass,
QuantWeightPass,
AddQuantDequantPass,
AddQuantDequantPassV2,
)
from .cal_kl_threshold import cal_kl_threshold
from .adaround import run_adaround
from . import utils
__all__ = [
'PostTrainingQuantization',
'WeightQuantization',
'PostTrainingQuantizationProgram',
]
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
......@@ -156,10 +150,10 @@ class PostTrainingQuantization:
Constructor.
Args:
executor(fluid.Executor): The executor to load, run and save the
executor(static.Executor): The executor to load, run and save the
quantized model.
scope(fluid.Scope, optional): The scope of the program, use it to load
and save variables. If scope=None, get scope by global_scope().
scope(static.Scope, optional): The scope of the program, use it to load
and save variables. If scope=None, get scope by static.global_scope().
model_dir(str): The path of the fp32 model that will be quantized,
and the model and params files are under the path.
model_filename(str, optional): The name of file to load the inference
......@@ -245,10 +239,10 @@ class PostTrainingQuantization:
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
import paddle.static as static
from paddle.static.quantization import PostTrainingQuantization
exe = fluid.Executor(fluid.CPUPlace())
exe = static.Executor(paddle.CPUPlace())
model_dir = path/to/fp32_model_params
# set model_filename as None when the filename is __model__,
# otherwise set it as the real filename
......@@ -344,7 +338,7 @@ class PostTrainingQuantization:
# Save input params
self._bias_correction = bias_correction
self._executor = executor
self._scope = global_scope() if scope is None else scope
self._scope = static.global_scope() if scope is None else scope
self._model_dir = model_dir
self._model_filename = model_filename
self._params_filename = params_filename
......@@ -537,22 +531,29 @@ class PostTrainingQuantization:
Args:
save_model_path(str): The path to save the quantized model.
model_filename(str, optional): If the model_filename is None,
save the model to '__model__'. Otherwise, save the model
to the specified filename. Default: None.
params_filename(str, optional): If the params_filename is None,
save params to separted files. Otherwise, save all params
to the specified filename.
save the model to 'model.pdmodel' and 'model.pdiparams'.
Otherwise, save the model to 'model_name.pdmodel' and
'model_name.pdiparams'. Default: None.
Returns:
None
'''
io.save_inference_model(
dirname=save_model_path,
model_filename=model_filename,
params_filename=params_filename,
feeded_var_names=self._feed_list,
target_vars=self._fetch_list,
model_name = None
if model_filename is None:
model_name = "model"
elif model_filename.endswith(".pdmodel"):
model_name = model_filename.rsplit(".", 1)[0]
else:
model_name = model_filename
path_prefix = os.path.join(save_model_path, model_name)
feed_vars = [
self._program.global_block().var(name) for name in self._feed_list
]
static.save_inference_model(
path_prefix,
feed_vars,
self._fetch_list,
executor=self._executor,
main_program=self._program,
program=self._program,
clip_extra=self._clip_extra,
)
_logger.info("The quantized model is saved in " + save_model_path)
......@@ -567,8 +568,8 @@ class PostTrainingQuantization:
self._program,
self._feed_list,
self._fetch_list,
] = io.load_inference_model(
dirname=self._model_dir,
] = static.load_inference_model(
self._model_dir,
executor=self._executor,
model_filename=self._model_filename,
params_filename=self._params_filename,
......@@ -578,7 +579,7 @@ class PostTrainingQuantization:
self._optimize_fp32_model()
feed_vars = [
framework._get_var(str(var_name), self._program)
_get_var(str(var_name), self._program)
for var_name in self._feed_list
]
......@@ -1632,17 +1633,17 @@ class WeightQuantization:
# Load model
place = core.CPUPlace()
exe = Executor(place)
scope = global_scope()
[infer_program, feed_list, fetch_list] = io.load_inference_model(
dirname=self._model_dir,
exe = static.Executor(place)
scope = static.global_scope()
[infer_program, feed_list, fetch_list] = static.load_inference_model(
self._model_dir,
executor=exe,
model_filename=self._model_filename,
params_filename=self._params_filename,
)
# Clone and save fp16 weights
save_program = framework.Program()
save_program = static.Program()
save_block = save_program.global_block()
save_var_map = {}
......@@ -1723,10 +1724,10 @@ class WeightQuantization:
"""
# Load model
place = core.CPUPlace()
exe = Executor(place)
scope = global_scope()
[program, feed_list, fetch_list] = io.load_inference_model(
dirname=self._model_dir,
exe = static.Executor(place)
scope = static.global_scope()
[program, feed_list, fetch_list] = static.load_inference_model(
self._model_dir,
executor=exe,
model_filename=self._model_filename,
params_filename=self._params_filename,
......@@ -1758,15 +1759,22 @@ class WeightQuantization:
self._weight_channel_wise_abs_max_quantization(
scope, place, weight_bits, op, var_name, for_test
)
io.save_inference_model(
dirname=save_model_dir,
feeded_var_names=feed_list,
target_vars=fetch_list,
model_name = None
if save_model_filename is None:
model_name = "model"
elif save_model_filename.endswith(".pdmodel"):
model_name = save_model_filename.rsplit(".", 1)[0]
else:
model_name = save_model_filename
path_prefix = os.path.join(save_model_dir, model_name)
feed_vars = [program.global_block().var(name) for name in feed_list]
static.save_inference_model(
path_prefix,
feed_vars,
fetch_list,
executor=exe,
main_program=program,
model_filename=save_model_filename,
params_filename=save_params_filename,
program=program,
)
def _weight_abs_max_quantization(
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,11 +13,9 @@
# limitations under the License.
import numpy as np
from .... import core
from ....framework import IrGraph
from ....framework import _get_paddle_place
__all__ = ['Quant2Int8MkldnnPass']
from ...fluid.framework import IrGraph
from ...framework import _get_paddle_place, core
OpRole = core.op_proto_and_checker_maker.OpRole
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,12 +13,9 @@
# limitations under the License.
import numpy as np
from .... import core
from ....framework import IrGraph
from ....framework import IrNode
from ....framework import _get_paddle_place
__all__ = ['QuantInt8MkldnnPass']
from ...fluid.framework import IrGraph
from ...framework import _get_paddle_place
class QuantInt8MkldnnPass:
......@@ -40,23 +37,23 @@ class QuantInt8MkldnnPass:
def __init__(self, _scope=None, _place=None):
r"""
Args:
scope(fluid.Scope): scope is used to initialize the new parameters.
place(fluid.CPUPlace|str): place is used to initialize the new parameters.
scope(static.Scope): scope is used to initialize the new parameters.
place(static.CPUPlace|str): place is used to initialize the new parameters.
When it is a string, it can only be 'cpu'.
Examples:
.. code-block:: python
# The original graph will be rewrite.
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization \
import paddle.static as static
from paddle.static.quantization \
import QuantInt8MkldnnPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
from paddle.framework import core
graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
place = fluid.CPUPlace()
mkldnn_pass = QuantInt8MkldnnPass(fluid.global_scope(),
graph = IrGraph(core.Graph(static.Program().desc), for_test=False)
place = static.CPUPlace()
mkldnn_pass = QuantInt8MkldnnPass(static.global_scope(),
place)
mkldnn_pass.apply(graph)
"""
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,39 +13,21 @@
# limitations under the License.
import collections
import numpy as np
try:
from tqdm import tqdm
except:
from .utils import tqdm
from .... import core
from ....framework import IrGraph
from ....framework import IrNode
from ....framework import Operator
from .... import unique_name
from ....framework import Program, program_guard, default_startup_program
from ....data import data
from ....executor import scope_guard
from ....framework import _get_paddle_place
from . import utils
import paddle
__all__ = [
'QuantizationTransformPass',
'QuantizationFreezePass',
'ConvertToInt8Pass',
'TransformForMobilePass',
'OutScaleForTrainingPass',
'OutScaleForInferencePass',
'AddQuantDequantPass',
'QuantizationTransformPassV2',
'AddQuantDequantPassV2',
'ReplaceFakeQuantDequantPass',
'QuantWeightPass',
'AddQuantDequantForInferencePass',
]
from ...fluid.framework import IrGraph, IrNode
from ...framework import _get_paddle_place, core
from ...static import Program, data, program_guard, scope_guard
from ...utils import unique_name
from . import utils
_fake_quant_op_list = [
'fake_quantize_abs_max',
......@@ -137,10 +119,10 @@ class QuantizationTransformPass:
Constructor.
Args:
scope(fluid.Scope): When activation use 'range_abs_max' as the quantize
scope(static.Scope): When the activation uses 'range_abs_max' as the quantize
type, this pass will create some new parameters. The scope is used to
initialize these new parameters.
place(fluid.CPUPlace|fluid.CUDAPlace|str): place is used to initialize new
place(static.CPUPlace|static.CUDAPlace|str): place is used to initialize new
parameters described above. If it's string, It can be ``cpu``, and ``gpu:x``,
where ``x`` is the index of the GPUs.
weight_bits(int): quantization bit number for weights,
......@@ -197,15 +179,15 @@ class QuantizationTransformPass:
Examples:
.. code-block:: python
# The original graph will be rewrite.
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization \
import paddle.static as static
from paddle.static.quantization \
import QuantizationTransformPass
from paddle.fluid.contrib.slim.graph import IrGraph
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.framework import core
graph = IrGraph(core.Graph(program.desc), for_test=False)
place = fluid.CPUPlace()
transform_pass = QuantizationTransformPass(fluid.global_scope(),
graph = IrGraph(core.Graph(static.Program().desc), for_test=False)
place = paddle.CPUPlace()
transform_pass = QuantizationTransformPass(static.global_scope(),
place)
transform_pass.apply(graph)
"""
......@@ -1094,8 +1076,8 @@ class QuantizationFreezePass:
and weight will be scaled offline.
Args:
scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace|str): place is used to restore the weight tensors.
scope(static.Scope): scope is used to get the weight tensor values.
place(static.CPUPlace|static.CUDAPlace|str): place is used to restore the weight tensors.
If it's string, It can be ``cpu``, and ``gpu:x``, where ``x`` is the index of the GPUs.
bias_correction(bool): whether use bias correction for post-training quantization.
https://arxiv.org/abs/1810.05723.
......@@ -1190,7 +1172,7 @@ class QuantizationFreezePass:
)
quantized_param_v = np.round(quantized_param_v)
# Weight bias correction
if self._bias_correction == True:
if self._bias_correction is True:
quantized_param_v = utils.bias_correction_w(
param_v,
quantized_param_v,
......@@ -1459,8 +1441,8 @@ class ConvertToInt8Pass:
Convert the weights into int8_t type.
Args:
scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace|str): place is used to restore the
scope(static.Scope): scope is used to get the weight tensor values.
place(static.CPUPlace|static.CUDAPlace|str): place is used to restore the
8bits weight tensors. If it's string, It can be ``cpu``, and ``gpu:x``,
where ``x`` is the index of the GPUs.
quantizable_op_type(list[str]): This input param will be removed latter. The pass
......@@ -1602,8 +1584,8 @@ class OutScaleForTrainingPass:
These output scales may be used by tensorRT or some other inference engines.
Args:
scope(fluid.Scope): The scope is used to initialize these new parameters.
place(fluid.CPUPlace|fluid.CUDAPlace|str): The place is used to initialize new parameters.
scope(static.Scope): The scope is used to initialize these new parameters.
place(static.CPUPlace|static.CUDAPlace|str): The place is used to initialize new parameters.
If it's string, It can be ``cpu``, and ``gpu:x``, where ``x`` is the
index of the GPUs.
moving_rate(float): The decay coefficient of moving average. The default value is 0.9.
......@@ -1764,7 +1746,7 @@ class OutScaleForInferencePass:
These output scales may be used by tensorRT or some other inference engines.
Args:
scope(fluid.Scope): The scope is used to initialize these new parameters.
scope(static.Scope): The scope is used to initialize these new parameters.
"""
self._scope = scope
self._teller_set = utils.QUANT_SUPPORTED_OP_TYPE_LIST
......@@ -1856,8 +1838,8 @@ class AddQuantDequantPass:
Constructor.
Args:
scope(fluid.Scope): The scope is used to initialize these new parameters.
place(fluid.CPUPlace|fluid.CUDAPlace|str): place is used to initialize new
scope(static.Scope): The scope is used to initialize these new parameters.
place(static.CPUPlace|static.CUDAPlace|str): place is used to initialize new
parameters described above. If ``place`` is a string, it can be ``cpu``
or ``gpu:x``, where ``x`` is the index of the GPUs.
moving_rate(float, optional): the param for 'quant_dequant_moving_average_abs_max'
......@@ -2452,12 +2434,12 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
.. code-block:: python
# The original graph will be rewrite.
import paddle
from paddle.fluid.contrib.slim.quantization \
from paddle.static.quantization \
import QuantizationTransformPassV2
from paddle.fluid.contrib.slim.graph import IrGraph
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.framework import core
graph = IrGraph(core.Graph(program.desc), for_test=False)
graph = IrGraph(core.Graph(static.Program().desc), for_test=False)
place = paddle.CPUPlace()
scope = paddle.static.global_scope()
transform_pass = QuantizationTransformPassV2(scope, place)
......@@ -2810,12 +2792,12 @@ class AddQuantDequantPassV2:
.. code-block:: python
# The original graph will be rewrite.
import paddle
from paddle.fluid.contrib.slim.quantization \
from paddle.static.quantization \
import AddQuantDequantPassV2
from paddle.fluid.contrib.slim.graph import IrGraph
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.framework import core
graph = IrGraph(core.Graph(program.desc), for_test=False)
graph = IrGraph(core.Graph(static.Program().desc), for_test=False)
place = paddle.CPUPlace()
scope = paddle.static.global_scope()
add_quant_dequant_pass = AddQuantDequantPassV2(scope, place)
......@@ -2977,12 +2959,12 @@ class ReplaceFakeQuantDequantPass:
.. code-block:: python
# The original graph will be rewrite.
import paddle
from paddle.fluid.contrib.slim.quantization \
from paddle.static.quantization \
import ReplaceFakeQuantDequantPass
from paddle.fluid.contrib.slim.graph import IrGraph
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.framework import core
graph = IrGraph(core.Graph(program.desc), for_test=False)
graph = IrGraph(core.Graph(static.Program().desc), for_test=False)
place = paddle.CPUPlace()
scope = paddle.static.global_scope()
replace_pass = ReplaceFakeQuantDequantPass(scope, place)
......@@ -3133,12 +3115,12 @@ class QuantWeightPass:
.. code-block:: python
# The original graph will be rewrite.
import paddle
from paddle.fluid.contrib.slim.quantization \
from paddle.static.quantization \
import QuantWeightPass
from paddle.fluid.contrib.slim.graph import IrGraph
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.framework import core
graph = IrGraph(core.Graph(program.desc), for_test=False)
graph = IrGraph(core.Graph(paddle.static.Program().desc), for_test=False)
place = paddle.CPUPlace()
scope = paddle.static.global_scope()
quant_weight_pass = QuantWeightPass(scope, place)
......@@ -3207,7 +3189,7 @@ class QuantWeightPass:
bits_length,
onnx_format=True,
)
if self._bias_correction == True:
if self._bias_correction is True:
quantized_param_v = utils.bias_correction_w(
param_v,
quantized_param_v,
......@@ -3264,7 +3246,7 @@ class AddQuantDequantForInferencePass:
def __init__(self, scope, place, quant_bits=8):
"""
Args:
scope(fluid.Scope): The scope is used to initialize these new parameters.
scope(static.Scope): The scope is used to initialize these new parameters.
place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to restore the weight tensors.
If it's string, it can be ``cpu``, and ``gpu:x``, where ``x`` is the index of the GPUs.
quant_bits(int, optional): quantization bit number for weight. Default is 8.
......
......@@ -250,7 +250,6 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model)
list(REMOVE_ITEM TEST_OPS test_imperative_ptq)
list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
list(REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2)
list(REMOVE_ITEM TEST_OPS test_imperative_qat_amp)
list(REMOVE_ITEM TEST_OPS test_imperative_qat_lsq)
list(REMOVE_ITEM TEST_OPS test_imperative_qat_matmul)
......
......@@ -91,17 +91,18 @@ Having gathered all the data needed for quantization we apply the `cpu_quantize_
The code snippet below shows how the `Quant2Int8MkldnnPass` can be applied to a model graph:
```python
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass
import paddle
import paddle.static as static
from paddle.static.quantization import Quant2Int8MkldnnPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
from paddle.framework import core
# Create the IrGraph by Program
graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
place = fluid.CPUPlace()
graph = IrGraph(core.Graph(static.Program().desc), for_test=False)
place = paddle.CPUPlace()
# Convert the IrGraph to MKL-DNN supported INT8 IrGraph using the
# Quant2Int8MkldnnPass. It requires a list of operators to be quantized
mkldnn_pass = Quant2Int8MkldnnPass({'conv2d', 'pool2d'}, fluid.global_scope(), place, fluid.core, False)
mkldnn_pass = Quant2Int8MkldnnPass({'conv2d', 'pool2d'}, static.global_scope(), place, core, False)
# Apply Quant2Int8MkldnnPass to IrGraph
mkldnn_pass.apply(graph)
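# A minimal follow-up sketch: after the pass has run, the quantized graph can be
# converted back to a Program and saved, as save_quant_model.py in this PR does
# (feed_names, fetch_targets and exe are assumed to come from a prior
# paddle.static.load_inference_model / static.Executor setup):
#   inference_program = graph.to_program()
#   feed_vars = [inference_program.global_block().var(name) for name in feed_names]
#   static.save_inference_model('/PATH/TO/SAVE/INT8/model', feed_vars, fetch_targets,
#                               executor=exe, program=inference_program)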
......@@ -263,7 +264,7 @@ The following options are also accepted:
```bash
cd /PATH/TO/PADDLE
OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim/tests/quant2_int8_image_classification_comparison.py --quant_model=/PATH/TO/DOWNLOADED/QUANT/MODEL --fp32_model=/PATH/TO/DOWNLOADED/FP32/MODEL --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=50 --batch_num=1000 --acc_diff_threshold=0.01 --ops_to_quantize="conv2d,pool2d"
OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/static/quantization/slim/tests/quant2_int8_image_classification_comparison.py --quant_model=/PATH/TO/DOWNLOADED/QUANT/MODEL --fp32_model=/PATH/TO/DOWNLOADED/FP32/MODEL --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=50 --batch_num=1000 --acc_diff_threshold=0.01 --ops_to_quantize="conv2d,pool2d"
```
> Notes: Due to the large number of images in the `int8_full_val.bin` dataset (50 000), the accuracy benchmark may take a long time. To accelerate the accuracy measurement, it is recommended to set `OMP_NUM_THREADS` to the maximum number of physical cores available on the server.
......@@ -276,7 +277,7 @@ To reproduce the performance results, the environment variable `OMP_NUM_THREADS=
```bash
cd /PATH/TO/PADDLE/build
python ../python/paddle/fluid/contrib/slim/tests/save_quant_model.py --quant_model_path=/PATH/TO/DOWNLOADED/QUANT/MODEL --int8_model_save_path=/PATH/TO/SAVE/QUANT/INT8/MODEL --ops_to_quantize="conv2d,pool2d"
python ../python/paddle/static/quantization/slim/tests/save_quant_model.py --quant_model_path=/PATH/TO/DOWNLOADED/QUANT/MODEL --int8_model_save_path=/PATH/TO/SAVE/QUANT/INT8/MODEL --ops_to_quantize="conv2d,pool2d"
```
2. Run the C-API test for performance benchmark.
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 paddlepaddle authors. all rights reserved.
# copyright (c) 2022 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
......@@ -12,14 +12,14 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import argparse
import os
import sys
import argparse
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
import unittest
import paddle
from paddle.fluid.framework import IrGraph
from paddle.framework import core
paddle.enable_static()
......@@ -47,29 +47,32 @@ def parse_args():
def generate_dot_for_model(model_path, save_graph_dir, save_graph_name):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
inference_scope = paddle.static.global_scope()
with paddle.static.scope_guard(inference_scope):
if os.path.exists(os.path.join(model_path, '__model__')):
[
inference_program,
feed_target_names,
fetch_targets,
] = fluid.io.load_inference_model(model_path, exe)
] = paddle.fluid.io.load_inference_model(model_path, exe)
else:
[
inference_program,
feed_target_names,
fetch_targets,
] = fluid.io.load_inference_model(
model_path, exe, 'model', 'params'
] = paddle.static.load_inference_model(
model_path,
exe,
model_filename='model',
params_filename='params',
)
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if not os.path.exists(save_graph_dir):
os.makedirs(save_graph_dir)
model_name = os.path.basename(os.path.normpath(save_graph_dir))
if save_graph_name is '':
if save_graph_name == '':
save_graph_name = model_name
graph.draw(save_graph_dir, save_graph_name, graph.all_op_nodes())
print(
......
......@@ -11,18 +11,27 @@
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import numpy as np
import logging
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.nn import Sequential
from paddle.nn import ReLU, ReLU6, LeakyReLU, Sigmoid, Softmax, PReLU
from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D
from paddle.nn import BatchNorm1D
import numpy as np
from paddle.fluid.log_helper import get_logger
import paddle
from paddle.framework import ParamAttr
from paddle.nn import (
BatchNorm1D,
BatchNorm2D,
Conv2D,
LeakyReLU,
Linear,
MaxPool2D,
PReLU,
ReLU,
ReLU6,
Sequential,
Sigmoid,
Softmax,
)
from paddle.static.log_helper import get_logger
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
......@@ -86,18 +95,18 @@ def train_lenet(lenet, reader, optimizer):
return loss_list
class ImperativeLenet(fluid.dygraph.Layer):
class ImperativeLenet(paddle.nn.Layer):
def __init__(self, num_classes=10):
super().__init__()
conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
fc_w2_attr = fluid.ParamAttr(name="fc_w_2")
fc_w3_attr = fluid.ParamAttr(name="fc_w_3")
conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2")
fc_b1_attr = fluid.ParamAttr(name="fc_b_1")
fc_b2_attr = fluid.ParamAttr(name="fc_b_2")
fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
conv2d_w1_attr = ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = ParamAttr(name="conv2d_w_2")
fc_w1_attr = ParamAttr(name="fc_w_1")
fc_w2_attr = ParamAttr(name="fc_w_2")
fc_w3_attr = ParamAttr(name="fc_w_3")
conv2d_b2_attr = ParamAttr(name="conv2d_b_2")
fc_b1_attr = ParamAttr(name="fc_b_1")
fc_b2_attr = ParamAttr(name="fc_b_2")
fc_b3_attr = ParamAttr(name="fc_b_3")
self.features = Sequential(
Conv2D(
in_channels=1,
......@@ -155,26 +164,26 @@ class ImperativeLenet(fluid.dygraph.Layer):
x = self.quant_stub(inputs)
x = self.features(x)
x = paddle.flatten(x, 1, -1)
x = paddle.flatten(x, 1)
x = self.add(x, paddle.to_tensor(0.0)) # For CI
x = self.fc(x)
return x
class ImperativeLenetWithSkipQuant(fluid.dygraph.Layer):
class ImperativeLenetWithSkipQuant(paddle.nn.Layer):
def __init__(self, num_classes=10):
super().__init__()
conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
fc_w2_attr = fluid.ParamAttr(name="fc_w_2")
fc_w3_attr = fluid.ParamAttr(name="fc_w_3")
conv2d_b1_attr = fluid.ParamAttr(name="conv2d_b_1")
conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2")
fc_b1_attr = fluid.ParamAttr(name="fc_b_1")
fc_b2_attr = fluid.ParamAttr(name="fc_b_2")
fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
conv2d_w1_attr = ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = ParamAttr(name="conv2d_w_2")
fc_w1_attr = ParamAttr(name="fc_w_1")
fc_w2_attr = ParamAttr(name="fc_w_2")
fc_w3_attr = ParamAttr(name="fc_w_3")
conv2d_b1_attr = ParamAttr(name="conv2d_b_1")
conv2d_b2_attr = ParamAttr(name="conv2d_b_2")
fc_b1_attr = ParamAttr(name="fc_b_1")
fc_b2_attr = ParamAttr(name="fc_b_2")
fc_b3_attr = ParamAttr(name="fc_b_3")
self.conv2d_0 = Conv2D(
in_channels=1,
out_channels=6,
......@@ -240,8 +249,7 @@ class ImperativeLenetWithSkipQuant(fluid.dygraph.Layer):
x = self.relu6_0(x)
x = self.pool2d_1(x)
x = paddle.flatten(x, 1, -1)
x = paddle.flatten(x, 1)
x = self.linear_0(x)
x = self.leaky_relu_0(x)
x = self.linear_1(x)
......@@ -252,7 +260,7 @@ class ImperativeLenetWithSkipQuant(fluid.dygraph.Layer):
return x
class ImperativeLinearBn(fluid.dygraph.Layer):
class ImperativeLinearBn(paddle.nn.Layer):
def __init__(self):
super().__init__()
......@@ -284,7 +292,7 @@ class ImperativeLinearBn(fluid.dygraph.Layer):
return x
class ImperativeLinearBn_hook(fluid.dygraph.Layer):
class ImperativeLinearBn_hook(paddle.nn.Layer):
def __init__(self):
super().__init__()
......
......@@ -12,19 +12,20 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import os
import sys
import argparse
import logging
import os
import struct
import numpy as np
import sys
import time
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass
from paddle.fluid import core
from paddle.framework import core
from paddle.static.quantization import Quant2Int8MkldnnPass
paddle.enable_static()
......@@ -185,23 +186,26 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
target='quant',
):
assert target in ['quant', 'int8', 'fp32']
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
inference_scope = paddle.static.global_scope()
with paddle.static.scope_guard(inference_scope):
if os.path.exists(os.path.join(model_path, '__model__')):
[
inference_program,
feed_target_names,
fetch_targets,
] = fluid.io.load_inference_model(model_path, exe)
] = paddle.fluid.io.load_inference_model(model_path, exe)
else:
[
inference_program,
feed_target_names,
fetch_targets,
] = fluid.io.load_inference_model(
model_path, exe, 'model', 'params'
] = paddle.static.load_inference_model(
model_path,
exe,
model_filename='model',
params_filename='params',
)
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
......@@ -359,7 +363,7 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
return set(map(int, string.split(',')))
def test_graph_transformation(self):
if not fluid.core.is_compiled_with_mkldnn():
if not core.is_compiled_with_mkldnn():
return
quant_model_path = test_case_args.quant_model
......
......@@ -13,15 +13,17 @@
# limitations under the License.
import argparse
import numpy as np
import struct
import sys
import time
import unittest
from paddle import fluid
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor
import numpy as np
from save_quant_model import transform_and_save_int8_model
import paddle
from paddle.framework import core
def parse_args():
parser = argparse.ArgumentParser()
......@@ -80,17 +82,19 @@ class TestLstmModelPTQ(unittest.TestCase):
[len(feat) // 4 // 8, 8]
)
lod_feat = [feat.shape[0]]
minputs = fluid.create_lod_tensor(feat, [lod_feat], place)
minputs = paddle.fluid.create_lod_tensor(
feat, [lod_feat], place
)
infer_data = fluid.core.PaddleTensor()
infer_data = core.PaddleTensor()
infer_data.lod = minputs.lod()
infer_data.data = fluid.core.PaddleBuf(np.array(minputs))
infer_data.data = core.PaddleBuf(np.array(minputs))
infer_data.shape = minputs.shape()
infer_data.dtype = fluid.core.PaddleDType.FLOAT32
infer_label = fluid.core.PaddleTensor()
infer_label.data = fluid.core.PaddleBuf(np.array(label))
infer_data.dtype = core.PaddleDType.FLOAT32
infer_label = core.PaddleTensor()
infer_label.data = core.PaddleBuf(np.array(label))
infer_label.shape = label.shape
infer_label.dtype = fluid.core.PaddleDType.INT32
infer_label.dtype = core.PaddleDType.INT32
data.append([infer_data, infer_label])
warmup_data = data[:1]
inputs = data[1:]
......@@ -105,7 +109,7 @@ class TestLstmModelPTQ(unittest.TestCase):
use_analysis=False,
enable_ptq=False,
):
config = AnalysisConfig(model_path)
config = core.AnalysisConfig(model_path)
config.set_cpu_math_library_num_threads(num_threads)
if use_analysis:
config.disable_gpu()
......@@ -132,7 +136,7 @@ class TestLstmModelPTQ(unittest.TestCase):
use_analysis=False,
enable_ptq=False,
):
place = fluid.CPUPlace()
place = paddle.CPUPlace()
warmup_data, inputs = self.get_warmup_tensor(data_path, place)
warmup_data = [item[0] for item in warmup_data]
config = self.set_config(
......@@ -144,7 +148,7 @@ class TestLstmModelPTQ(unittest.TestCase):
enable_ptq,
)
predictor = create_paddle_predictor(config)
predictor = core.create_paddle_predictor(config)
data = [item[0] for item in inputs]
label = np.array([item[1] for item in inputs])
......@@ -197,7 +201,7 @@ class TestLstmModelPTQ(unittest.TestCase):
return hx_acc, ctc_acc, fps
def test_lstm_model(self):
if not fluid.core.is_compiled_with_mkldnn():
if not core.is_compiled_with_mkldnn():
return
fp32_model = test_case_args.fp32_model
......
......@@ -12,18 +12,19 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import os
import sys
import argparse
import logging
import numpy as np
import os
import sys
import time
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass
from paddle.fluid import core
from paddle.framework import core
from paddle.static.quantization import Quant2Int8MkldnnPass
paddle.enable_static()
......@@ -158,23 +159,26 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
target='quant',
):
assert target in ['quant', 'int8', 'fp32']
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
inference_scope = paddle.static.global_scope()
with paddle.static.scope_guard(inference_scope):
if os.path.exists(os.path.join(model_path, '__model__')):
[
inference_program,
feed_target_names,
fetch_targets,
] = fluid.io.load_inference_model(model_path, exe)
] = paddle.fluid.io.load_inference_model(model_path, exe)
else:
[
inference_program,
feed_target_names,
fetch_targets,
] = fluid.io.load_inference_model(
model_path, exe, 'model', 'params'
] = paddle.static.load_inference_model(
model_path,
exe,
model_filename='model',
params_filename='params',
)
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
......@@ -296,7 +300,7 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
return set(map(int, string.split(',')))
def test_graph_transformation(self):
if not fluid.core.is_compiled_with_mkldnn():
if not core.is_compiled_with_mkldnn():
return
quant_model_path = test_case_args.quant_model
......
......@@ -12,19 +12,20 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import os
import sys
import argparse
import logging
import os
import struct
import numpy as np
import sys
import time
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantInt8MkldnnPass
from paddle.fluid import core
from paddle.framework import core
from paddle.static.quantization import QuantInt8MkldnnPass
paddle.enable_static()
......@@ -163,23 +164,26 @@ class QuantInt8ImageClassificationComparisonTest(unittest.TestCase):
skip_batch_num=0,
transform_to_int8=False,
):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
inference_scope = paddle.static.global_scope()
with paddle.static.scope_guard(inference_scope):
if os.path.exists(os.path.join(model_path, '__model__')):
[
inference_program,
feed_target_names,
fetch_targets,
] = fluid.io.load_inference_model(model_path, exe)
] = paddle.fluid.io.load_inference_model(model_path, exe)
else:
[
inference_program,
feed_target_names,
fetch_targets,
] = fluid.io.load_inference_model(
model_path, exe, 'model', 'params'
] = paddle.static.load_inference_model(
model_path,
exe,
model_filename='model',
params_filename='params',
)
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
......@@ -298,7 +302,7 @@ class QuantInt8ImageClassificationComparisonTest(unittest.TestCase):
assert fp32_acc1 - int8_acc1 <= threshold
def test_graph_transformation(self):
if not fluid.core.is_compiled_with_mkldnn():
if not core.is_compiled_with_mkldnn():
return
quant_model_path = test_case_args.quant_model
......
......@@ -12,15 +12,15 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import argparse
import os
import sys
import argparse
import unittest
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass
from paddle.fluid import core
from paddle.framework import core
from paddle.static.quantization import Quant2Int8MkldnnPass
paddle.enable_static()
......@@ -93,35 +93,41 @@ def transform_and_save_int8_model(
debug=False,
quant_model_filename='',
quant_params_filename='',
save_model_filename="__model__",
save_model_filename="model",
save_params_filename=None,
):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
inference_scope = paddle.static.global_scope()
with paddle.static.scope_guard(inference_scope):
if not quant_model_filename:
if os.path.exists(os.path.join(original_path, '__model__')):
[
inference_program,
feed_target_names,
fetch_targets,
] = fluid.io.load_inference_model(original_path, exe)
] = paddle.fluid.io.load_inference_model(original_path, exe)
else:
[
inference_program,
feed_target_names,
fetch_targets,
] = fluid.io.load_inference_model(
original_path, exe, 'model', 'params'
] = paddle.static.load_inference_model(
original_path,
exe,
model_filename='model',
params_filename='params',
)
else:
[
inference_program,
feed_target_names,
fetch_targets,
] = fluid.io.load_inference_model(
original_path, exe, quant_model_filename, quant_params_filename
] = paddle.static.load_inference_model(
original_path,
exe,
model_filename=quant_model_filename,
params_filename=quant_params_filename,
)
ops_to_quantize_set = set()
......@@ -147,15 +153,18 @@ def transform_and_save_int8_model(
)
graph = transform_to_mkldnn_int8_pass.apply(graph)
inference_program = graph.to_program()
with fluid.scope_guard(inference_scope):
fluid.io.save_inference_model(
save_path,
feed_target_names,
with paddle.static.scope_guard(inference_scope):
path_prefix = os.path.join(save_path, save_model_filename)
feed_vars = [
inference_program.global_block().var(name)
for name in feed_target_names
]
paddle.static.save_inference_model(
path_prefix,
feed_vars,
fetch_targets,
exe,
inference_program,
model_filename=save_model_filename,
params_filename=save_params_filename,
executor=exe,
program=inference_program,
)
print(
"Success! INT8 model obtained from the Quant model can be found at {}\n".format(
......
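On the save side, `paddle.static.save_inference_model` wants a path prefix and feed Variables rather than a target directory and feed names, which is why the hunk above builds `path_prefix` and `feed_vars` first. A small helper sketching that pattern (the argument names are illustrative):

import os

import paddle


def save_int8_inference_model(
    save_path, save_model_filename, inference_program,
    feed_target_names, fetch_targets, exe
):
    # path prefix instead of a target directory
    path_prefix = os.path.join(save_path, save_model_filename)
    # the 2.x API takes Variables, so look the feed names up in the program
    feed_vars = [
        inference_program.global_block().var(name)
        for name in feed_target_names
    ]
    paddle.static.save_inference_model(
        path_prefix,
        feed_vars,
        fetch_targets,
        executor=exe,
        program=inference_program,
    )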
......@@ -13,12 +13,13 @@
# limitations under the license.
import os
import numpy as np
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
from paddle.framework import core
paddle.enable_static()
......@@ -27,63 +28,68 @@ os.environ["CPU_NUM"] = "1"
def conv_block():
img = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
conv_pool_1 = fluid.nets.simple_img_conv_pool(
img = paddle.static.data(
name='image', shape=[-1, 1, 28, 28], dtype='float32'
)
label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')
conv_out_1 = paddle.static.nn.conv2d(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
act="relu",
act='relu',
)
conv_pool_1 = paddle.nn.functional.max_pool2d(
conv_out_1, kernel_size=2, stride=2
)
conv_pool_1 = paddle.static.nn.batch_norm(conv_pool_1)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
conv_out_2 = paddle.static.nn.conv2d(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
act="relu",
num_filters=20,
act='relu',
)
conv_pool_2 = paddle.nn.functional.max_pool2d(
conv_out_2, kernel_size=2, stride=2
)
prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
loss = paddle.nn.functional.cross_entropy(
input=prediction, label=label, reduction='none', use_softmax=False
prediction = paddle.static.nn.fc(
x=conv_pool_2, size=10, activation='softmax'
)
loss = paddle.nn.functional.cross_entropy(input=prediction, label=label)
avg_loss = paddle.mean(loss)
return [img, label], avg_loss
class TestGraph(unittest.TestCase):
def graph_apis(self, use_cuda=False, for_ci=True):
main = fluid.Program()
startup = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main, startup):
feeds, loss = conv_block()
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt = paddle.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
graph = IrGraph(core.Graph(main.desc), for_test=False)
backup_graph = graph.clone()
self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes()))
build_strategy = fluid.BuildStrategy()
build_strategy = paddle.static.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
origin_binary = fluid.CompiledProgram(graph.graph).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy
)
backup_binary = fluid.CompiledProgram(
origin_binary = paddle.static.CompiledProgram(
graph.graph
).with_data_parallel(loss_name=loss.name, build_strategy=build_strategy)
backup_binary = paddle.static.CompiledProgram(
backup_graph.graph
).with_data_parallel(loss_name=loss.name, build_strategy=build_strategy)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup)
iters = 5
batch_size = 8
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size
)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
feeder = paddle.fluid.DataFeeder(feed_list=feeds, place=place)
def _train(binary):
for _ in range(iters):
......@@ -105,17 +111,29 @@ class TestGraph(unittest.TestCase):
var.set(var_array, place)
sum_before = np.sum(
np.array(fluid.global_scope().find_var('conv2d_1.w_0').get_tensor())
np.array(
paddle.static.global_scope()
.find_var('conv2d_1.w_0')
.get_tensor()
)
)
fluid.io._save_persistable_nodes(exe, checkponit_dir, graph)
_set_zero('conv2d_1.w_0', fluid.global_scope(), place)
paddle.fluid.io._save_persistable_nodes(exe, checkponit_dir, graph)
_set_zero('conv2d_1.w_0', paddle.static.global_scope(), place)
set_after = np.sum(
np.array(fluid.global_scope().find_var('conv2d_1.w_0').get_tensor())
np.array(
paddle.static.global_scope()
.find_var('conv2d_1.w_0')
.get_tensor()
)
)
self.assertEqual(set_after, 0)
fluid.io._load_persistable_nodes(exe, checkponit_dir, graph)
paddle.fluid.io._load_persistable_nodes(exe, checkponit_dir, graph)
sum_after = np.sum(
np.array(fluid.global_scope().find_var('conv2d_1.w_0').get_tensor())
np.array(
paddle.static.global_scope()
.find_var('conv2d_1.w_0')
.get_tensor()
)
)
self.assertEqual(sum_before, sum_after)
......@@ -144,7 +162,7 @@ class TestGraph(unittest.TestCase):
self.graph_apis(use_cuda=False, for_ci=True)
def test_graph_apis_cuda(self):
if fluid.core.is_compiled_with_cuda():
if core.is_compiled_with_cuda():
self.graph_apis(use_cuda=True, for_ci=True)
......
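For reference, a self-contained sketch of how the test above now builds an IrGraph from a pure 2.x static Program; the tiny fc network is a stand-in for the conv block used in the test.

import paddle
from paddle.fluid.framework import IrGraph
from paddle.framework import core

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.utils.unique_name.guard():
    with paddle.static.program_guard(main, startup):
        x = paddle.static.data(name='x', shape=[-1, 8], dtype='float32')
        hidden = paddle.static.nn.fc(x, size=2)
        loss = paddle.mean(hidden)
        paddle.optimizer.Adam(learning_rate=0.001).minimize(loss)

# wrap the C++ graph and clone it, as the test does
graph = IrGraph(core.Graph(main.desc), for_test=False)
backup_graph = graph.clone()
assert len(graph.all_nodes()) == len(backup_graph.all_nodes())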
......@@ -13,38 +13,31 @@
# limitations under the license.
import os
import numpy as np
import random
import unittest
import logging
import warnings
import tempfile
import unittest
import numpy as np
from imperative_test_utils import fix_model_dict, train_lenet
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid import core
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.nn import Sequential
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn.layer import ReLU, LeakyReLU, Sigmoid, Softmax, PReLU
from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D
from paddle.fluid.log_helper import get_logger
from paddle.fluid.dygraph import nn
from imperative_test_utils import fix_model_dict, train_lenet
from paddle.framework import core, set_flags
from paddle.nn import (
BatchNorm2D,
Conv2D,
Linear,
MaxPool2D,
Sequential,
Softmax,
)
from paddle.nn.layer import LeakyReLU, PReLU, ReLU, Sigmoid
from paddle.quantization import ImperativeQuantAware
paddle.enable_static()
os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
set_flags({"FLAGS_cudnn_deterministic": True})
def get_vaild_warning_num(warning, w):
......@@ -55,18 +48,18 @@ def get_vaild_warning_num(warning, w):
return num
class ImperativeLenet(fluid.dygraph.Layer):
class ImperativeLenet(paddle.nn.Layer):
def __init__(self, num_classes=10):
super().__init__()
conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
fc_w2_attr = fluid.ParamAttr(name="fc_w_2")
fc_w3_attr = fluid.ParamAttr(name="fc_w_3")
conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2")
fc_b1_attr = fluid.ParamAttr(name="fc_b_1")
fc_b2_attr = fluid.ParamAttr(name="fc_b_2")
fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
conv2d_w1_attr = paddle.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = paddle.ParamAttr(name="conv2d_w_2")
fc_w1_attr = paddle.ParamAttr(name="fc_w_1")
fc_w2_attr = paddle.ParamAttr(name="fc_w_2")
fc_w3_attr = paddle.ParamAttr(name="fc_w_3")
conv2d_b2_attr = paddle.ParamAttr(name="conv2d_b_2")
fc_b1_attr = paddle.ParamAttr(name="fc_b_1")
fc_b2_attr = paddle.ParamAttr(name="fc_b_2")
fc_b3_attr = paddle.ParamAttr(name="fc_b_3")
self.features = Sequential(
Conv2D(
in_channels=1,
......@@ -121,7 +114,7 @@ class ImperativeLenet(fluid.dygraph.Layer):
def forward(self, inputs):
x = self.features(inputs)
x = paddle.flatten(x, 1, -1)
x = paddle.flatten(x, 1)
x = self.fc(x)
return x
......@@ -152,8 +145,8 @@ class TestImperativeOutSclae(unittest.TestCase):
with fluid.dygraph.guard():
np.random.seed(seed)
fluid.default_main_program().random_seed = seed
fluid.default_startup_program().random_seed = seed
paddle.static.default_main_program().random_seed = seed
paddle.static.default_startup_program().random_seed = seed
lenet = ImperativeLenet()
lenet = fix_model_dict(lenet)
......@@ -162,8 +155,8 @@ class TestImperativeOutSclae(unittest.TestCase):
reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=32, drop_last=True
)
adam = AdamOptimizer(
learning_rate=lr, parameter_list=lenet.parameters()
adam = paddle.optimizer.Adam(
learning_rate=lr, parameters=lenet.parameters()
)
loss_list = train_lenet(lenet, reader, adam)
lenet.eval()
......@@ -186,8 +179,8 @@ class TestImperativeOutSclae(unittest.TestCase):
reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=32, drop_last=True
)
adam = AdamOptimizer(
learning_rate=lr, parameter_list=lenet.parameters()
adam = paddle.optimizer.Adam(
learning_rate=lr, parameters=lenet.parameters()
)
loss_list = train_lenet(lenet, reader, adam)
lenet.eval()
......
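The optimizer change above follows the general 1.x→2.x pattern: `paddle.optimizer.Adam` takes `parameters=` where the old `AdamOptimizer` took `parameter_list=`. A minimal dygraph sketch, with a Linear layer standing in for the LeNet used in the test:

import paddle

model = paddle.nn.Linear(10, 10)  # stand-in model
adam = paddle.optimizer.Adam(
    learning_rate=0.001, parameters=model.parameters()
)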
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
# copyright (c) 2022 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
......@@ -12,29 +12,32 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import logging
import os
import numpy as np
import random
import shutil
import tempfile
import time
import unittest
import copy
import logging
import tempfile
import paddle.nn as nn
import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import *
from paddle.fluid.log_helper import get_logger
from paddle.dataset.common import download
import numpy as np
from imperative_test_utils import (
fix_model_dict,
ImperativeLenet,
ImperativeLinearBn,
ImperativeLinearBn_hook,
)
from imperative_test_utils import ImperativeLinearBn_hook
import paddle
import paddle.nn as nn
from paddle.dataset.common import download
from paddle.fluid.framework import _test_eager_guard
from paddle.quantization import (
AbsmaxQuantizer,
HistQuantizer,
ImperativePTQ,
KLQuantizer,
PerChannelAbsmaxQuantizer,
PTQConfig,
)
from paddle.static.log_helper import get_logger
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
......@@ -149,8 +152,8 @@ class TestImperativePTQ(unittest.TestCase):
label = paddle.to_tensor(y_data)
out = model(img)
acc_top1 = paddle.static.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.static.accuracy(input=out, label=label, k=5)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
eval_acc_top1_list.append(float(acc_top1.numpy()))
if batch_id % 50 == 0:
......@@ -207,7 +210,7 @@ class TestImperativePTQ(unittest.TestCase):
break
return top1_correct_num / total_num
def test_ptq(self):
def func_ptq(self):
start_time = time.time()
self.set_vars()
......@@ -265,9 +268,14 @@ class TestImperativePTQ(unittest.TestCase):
end_time = time.time()
print("total time: %ss \n" % (end_time - start_time))
def test_ptq(self):
with _test_eager_guard():
self.func_ptq()
self.func_ptq()
class TestImperativePTQfuse(TestImperativePTQ):
def test_ptq(self):
def func_ptq(self):
start_time = time.time()
self.set_vars()
......@@ -336,6 +344,11 @@ class TestImperativePTQfuse(TestImperativePTQ):
end_time = time.time()
print("total time: %ss \n" % (end_time - start_time))
def test_ptq(self):
with _test_eager_guard():
self.func_ptq()
self.func_ptq()
class TestImperativePTQHist(TestImperativePTQ):
def set_vars(self):
......
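For orientation, a hedged sketch of the imperative post-training quantization objects imported above. It assumes `PTQConfig` takes an activation quantizer and a weight quantizer and that `ImperativePTQ.quantize` returns the prepared model, as the test exercises them; the toy model is a placeholder.

import paddle
from paddle.quantization import (
    ImperativePTQ,
    KLQuantizer,
    PerChannelAbsmaxQuantizer,
    PTQConfig,
)

# assumed signature: PTQConfig(activation_quantizer, weight_quantizer)
config = PTQConfig(KLQuantizer(), PerChannelAbsmaxQuantizer())
ptq = ImperativePTQ(config)

model = paddle.nn.Sequential(
    paddle.nn.Conv2D(1, 6, 3, padding=1), paddle.nn.ReLU()
)  # placeholder for the LeNet / LinearBn models in the tests
quant_model = ptq.quantize(model)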
......@@ -12,34 +12,34 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import logging
import os
import numpy as np
import random
import time
import tempfile
import unittest
import logging
import numpy as np
from imperative_test_utils import ImperativeLenet, fix_model_dict
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.nn import Sequential
from paddle.nn import Linear, Conv2D, Softmax, Conv2DTranspose
from paddle.fluid.log_helper import get_logger
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.framework import core, set_flags
from paddle.nn import Conv2D, Conv2DTranspose
from paddle.nn.quant.quant_layers import (
QuantizedConv2D,
QuantizedConv2DTranspose,
)
from imperative_test_utils import fix_model_dict, ImperativeLenet
from paddle.optimizer import Adam
from paddle.quantization import ImperativeQuantAware
from paddle.static.log_helper import get_logger
INFER_MODEL_SUFFIX = ".pdmodel"
INFER_PARAMS_SUFFIX = ".pdiparams"
paddle.enable_static()
os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
set_flags({"FLAGS_cudnn_deterministic": True})
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
......@@ -84,7 +84,7 @@ class TestImperativeQat(unittest.TestCase):
)
quant_conv1 = QuantizedConv2D(conv1)
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
quant_conv1(fluid.dygraph.to_variable(data))
quant_conv1(paddle.to_tensor(data))
conv_transpose = Conv2DTranspose(4, 6, (3, 3))
quant_conv_transpose = QuantizedConv2DTranspose(conv_transpose)
......@@ -95,15 +95,13 @@ class TestImperativeQat(unittest.TestCase):
seed = 1
np.random.seed(seed)
fluid.default_main_program().random_seed = seed
fluid.default_startup_program().random_seed = seed
paddle.static.default_main_program().random_seed = seed
paddle.static.default_startup_program().random_seed = seed
lenet = ImperativeLenet()
lenet = fix_model_dict(lenet)
imperative_qat.quantize(lenet)
adam = AdamOptimizer(
learning_rate=0.001, parameter_list=lenet.parameters()
)
adam = Adam(learning_rate=0.001, parameters=lenet.parameters())
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=32, drop_last=True
......@@ -125,10 +123,10 @@ class TestImperativeQat(unittest.TestCase):
.reshape(-1, 1)
)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
out = lenet(img)
acc = paddle.static.accuracy(out, label)
acc = paddle.metric.accuracy(out, label)
loss = paddle.nn.functional.cross_entropy(
out, label, reduction='none', use_softmax=False
)
......@@ -157,14 +155,14 @@ class TestImperativeQat(unittest.TestCase):
.reshape(-1, 1)
)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
out = lenet(img)
acc_top1 = paddle.static.accuracy(
acc_top1 = paddle.metric.accuracy(
input=out, label=label, k=1
)
acc_top5 = paddle.static.accuracy(
acc_top5 = paddle.metric.accuracy(
input=out, label=label, k=5
)
......@@ -197,11 +195,11 @@ class TestImperativeQat(unittest.TestCase):
y_data = (
np.array([x[1] for x in data]).astype('int64').reshape(-1, 1)
)
test_img = fluid.dygraph.to_variable(test_data)
label = fluid.dygraph.to_variable(y_data)
test_img = paddle.to_tensor(test_data)
label = paddle.to_tensor(y_data)
lenet.eval()
fp32_out = lenet(test_img)
fp32_acc = paddle.static.accuracy(fp32_out, label).numpy()
fp32_acc = paddle.metric.accuracy(fp32_out, label).numpy()
with tempfile.TemporaryDirectory(prefix="qat_save_path_") as tmpdir:
# save inference quantized model
......@@ -220,13 +218,13 @@ class TestImperativeQat(unittest.TestCase):
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe = paddle.static.Executor(place)
[
inference_program,
feed_target_names,
fetch_targets,
] = fluid.io.load_inference_model(
dirname=tmpdir,
] = paddle.static.load_inference_model(
tmpdir,
executor=exe,
model_filename="lenet" + INFER_MODEL_SUFFIX,
params_filename="lenet" + INFER_PARAMS_SUFFIX,
......@@ -237,8 +235,8 @@ class TestImperativeQat(unittest.TestCase):
fetch_list=fetch_targets,
)
paddle.disable_static()
quant_out = fluid.dygraph.to_variable(quant_out)
quant_acc = paddle.static.accuracy(quant_out, label).numpy()
quant_out = paddle.to_tensor(quant_out)
quant_acc = paddle.metric.accuracy(quant_out, label).numpy()
paddle.enable_static()
delta_value = fp32_acc - quant_acc
self.assertLessEqual(delta_value, self.diff_threshold)
......
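A compact sketch of the dygraph QAT flow the test above walks through, using the relocated `paddle.quantization.ImperativeQuantAware` with its default settings; the toy model and shapes are placeholders.

import numpy as np
import paddle
from paddle.quantization import ImperativeQuantAware

imperative_qat = ImperativeQuantAware()  # default weight/activation quantizers

model = paddle.nn.Sequential(
    paddle.nn.Conv2D(1, 6, 3, padding=1),
    paddle.nn.ReLU(),
    paddle.nn.Flatten(),
    paddle.nn.Linear(6 * 28 * 28, 10),
)
imperative_qat.quantize(model)  # inserts fake-quant layers in place, as in the test

x = paddle.to_tensor(np.random.rand(4, 1, 28, 28).astype('float32'))
out = model(x)  # the forward pass now runs through the quantized wrappers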
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
# copyright (c) 2022 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
......@@ -12,25 +12,25 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import logging
import os
import numpy as np
import random
import shutil
import tempfile
import time
import unittest
import logging
import tempfile
import numpy as np
from imperative_test_utils import ImperativeLenet
import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.fluid.log_helper import get_logger
from paddle.dataset.common import download
from imperative_test_utils import fix_model_dict, ImperativeLenet
from paddle.framework import set_flags
from paddle.quantization import ImperativeQuantAware
from paddle.static.log_helper import get_logger
os.environ["CPU_NUM"] = "1"
if paddle.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
set_flags({"FLAGS_cudnn_deterministic": True})
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
......@@ -117,7 +117,7 @@ class TestImperativeQatAmp(unittest.TestCase):
if use_amp:
with paddle.amp.auto_cast():
out = model(img)
acc = paddle.static.accuracy(out, label)
acc = paddle.metric.accuracy(out, label)
loss = paddle.nn.functional.cross_entropy(
out, label, reduction='none', use_softmax=False
)
......@@ -129,7 +129,7 @@ class TestImperativeQatAmp(unittest.TestCase):
adam.clear_gradients()
else:
out = model(img)
acc = paddle.static.accuracy(out, label)
acc = paddle.metric.accuracy(out, label)
loss = paddle.nn.functional.cross_entropy(
out, label, reduction='none', use_softmax=False
)
......@@ -170,8 +170,8 @@ class TestImperativeQatAmp(unittest.TestCase):
with paddle.amp.auto_cast(use_amp):
out = model(img)
acc_top1 = paddle.static.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.static.accuracy(input=out, label=label, k=5)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
acc_top1_list.append(float(acc_top1.numpy()))
if batch_id % 100 == 0:
......
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
# copyright (c) 2022 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
......@@ -13,27 +13,18 @@
# limitations under the license.
import os
import numpy as np
import random
import unittest
import logging
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.log_helper import get_logger
from test_imperative_qat import TestImperativeQat
import paddle
from paddle.framework import core, set_flags
paddle.enable_static()
os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
set_flags({"FLAGS_cudnn_deterministic": True})
class TestImperativeQatChannelWise(TestImperativeQat):
......
......@@ -13,27 +13,18 @@
# limitations under the license.
import os
import numpy as np
import random
import unittest
import logging
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.log_helper import get_logger
from test_imperative_qat import TestImperativeQat
import paddle
from paddle.framework import core, set_flags
paddle.enable_static()
os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
set_flags({"FLAGS_cudnn_deterministic": True})
class TestImperativeQatfuseBN(TestImperativeQat):
......
......@@ -12,57 +12,53 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import logging
import os
import numpy as np
import random
import time
import tempfile
import unittest
import logging
import numpy as np
from imperative_test_utils import fix_model_dict
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import (
SGDOptimizer,
AdamOptimizer,
MomentumOptimizer,
from paddle.framework import core, set_flags
from paddle.nn import (
BatchNorm2D,
Conv2D,
LeakyReLU,
Linear,
MaxPool2D,
PReLU,
ReLU,
Sequential,
Sigmoid,
Softmax,
)
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.nn import Sequential
from paddle.nn import ReLU, ReLU6, LeakyReLU, Sigmoid, Softmax, PReLU
from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D
from paddle.fluid.log_helper import get_logger
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn.quant.quant_layers import (
QuantizedConv2D,
QuantizedConv2DTranspose,
)
from imperative_test_utils import fix_model_dict
from paddle.quantization import ImperativeQuantAware
from paddle.static.log_helper import get_logger
paddle.enable_static()
os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
set_flags({"FLAGS_cudnn_deterministic": True})
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
class ImperativeLenet(fluid.dygraph.Layer):
class ImperativeLenet(paddle.nn.Layer):
def __init__(self, num_classes=10):
super().__init__()
conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
fc_w2_attr = fluid.ParamAttr(name="fc_w_2")
fc_w3_attr = fluid.ParamAttr(name="fc_w_3")
conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2")
fc_b1_attr = fluid.ParamAttr(name="fc_b_1")
fc_b2_attr = fluid.ParamAttr(name="fc_b_2")
fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
conv2d_w1_attr = paddle.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = paddle.ParamAttr(name="conv2d_w_2")
fc_w1_attr = paddle.ParamAttr(name="fc_w_1")
fc_w2_attr = paddle.ParamAttr(name="fc_w_2")
fc_w3_attr = paddle.ParamAttr(name="fc_w_3")
conv2d_b2_attr = paddle.ParamAttr(name="conv2d_b_2")
fc_b1_attr = paddle.ParamAttr(name="fc_b_1")
fc_b2_attr = paddle.ParamAttr(name="fc_b_2")
fc_b3_attr = paddle.ParamAttr(name="fc_b_3")
self.features = Sequential(
Conv2D(
in_channels=1,
......@@ -116,7 +112,7 @@ class ImperativeLenet(fluid.dygraph.Layer):
def forward(self, inputs):
x = self.features(inputs)
x = paddle.flatten(x, 1, -1)
x = paddle.flatten(x, 1)
x = self.fc(x)
return x
......@@ -139,14 +135,14 @@ class TestImperativeQatLSQ(unittest.TestCase):
seed = 100
np.random.seed(seed)
fluid.default_main_program().random_seed = seed
fluid.default_startup_program().random_seed = seed
paddle.static.default_main_program().random_seed = seed
paddle.static.default_startup_program().random_seed = seed
paddle.disable_static()
lenet = ImperativeLenet()
lenet = fix_model_dict(lenet)
imperative_qat.quantize(lenet)
optimizer = MomentumOptimizer(
learning_rate=0.1, parameter_list=lenet.parameters(), momentum=0.9
optimizer = paddle.optimizer.Momentum(
learning_rate=0.1, parameters=lenet.parameters(), momentum=0.9
)
train_reader = paddle.batch(
......@@ -166,10 +162,10 @@ class TestImperativeQatLSQ(unittest.TestCase):
.reshape(-1, 1)
)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
out = lenet(img)
acc = paddle.static.accuracy(out, label)
acc = paddle.metric.accuracy(out, label)
loss = paddle.nn.functional.cross_entropy(
out, label, reduction='none', use_softmax=False
)
......@@ -199,14 +195,14 @@ class TestImperativeQatLSQ(unittest.TestCase):
.astype('int64')
.reshape(-1, 1)
)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
out = lenet(img)
acc_top1 = paddle.static.accuracy(
acc_top1 = paddle.metric.accuracy(
input=out, label=label, k=1
)
acc_top5 = paddle.static.accuracy(
acc_top5 = paddle.metric.accuracy(
input=out, label=label, k=5
)
......
......@@ -12,57 +12,55 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import logging
import os
import numpy as np
import random
import time
import tempfile
import unittest
import logging
import numpy as np
from imperative_test_utils import fix_model_dict
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import (
SGDOptimizer,
AdamOptimizer,
MomentumOptimizer,
from paddle.framework import core, set_flags
from paddle.nn import (
BatchNorm2D,
Conv2D,
LeakyReLU,
Linear,
MaxPool2D,
PReLU,
ReLU,
Sequential,
Sigmoid,
Softmax,
)
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.nn import Sequential
from paddle.nn import ReLU, ReLU6, LeakyReLU, Sigmoid, Softmax, PReLU
from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D
from paddle.fluid.log_helper import get_logger
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn.quant.quant_layers import (
QuantizedConv2D,
QuantizedMatmul,
)
from imperative_test_utils import fix_model_dict
from paddle.nn.quant.quant_layers import QuantizedMatmul
from paddle.optimizer import Momentum
from paddle.quantization import ImperativeQuantAware
from paddle.static.log_helper import get_logger
paddle.enable_static()
os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
set_flags({"FLAGS_cudnn_deterministic": True})
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
class ImperativeLenet(fluid.dygraph.Layer):
class ImperativeLenet(paddle.nn.Layer):
def __init__(self, num_classes=10):
super().__init__()
conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
fc_w2_attr = fluid.ParamAttr(name="fc_w_2")
fc_w3_attr = fluid.ParamAttr(name="fc_w_3")
conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2")
fc_b1_attr = fluid.ParamAttr(name="fc_b_1")
fc_b2_attr = fluid.ParamAttr(name="fc_b_2")
fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
conv2d_w1_attr = paddle.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = paddle.ParamAttr(name="conv2d_w_2")
fc_w1_attr = paddle.ParamAttr(name="fc_w_1")
fc_w2_attr = paddle.ParamAttr(name="fc_w_2")
fc_w3_attr = paddle.ParamAttr(name="fc_w_3")
conv2d_b2_attr = paddle.ParamAttr(name="conv2d_b_2")
fc_b1_attr = paddle.ParamAttr(name="fc_b_1")
fc_b2_attr = paddle.ParamAttr(name="fc_b_2")
fc_b3_attr = paddle.ParamAttr(name="fc_b_3")
self.features = Sequential(
Conv2D(
in_channels=1,
......@@ -140,15 +138,15 @@ class TestImperativeQatMatmul(unittest.TestCase):
seed = 100
np.random.seed(seed)
fluid.default_main_program().random_seed = seed
fluid.default_startup_program().random_seed = seed
paddle.static.default_main_program().random_seed = seed
paddle.static.default_startup_program().random_seed = seed
paddle.disable_static()
lenet = ImperativeLenet()
lenet = fix_model_dict(lenet)
imperative_qat.quantize(lenet)
optimizer = MomentumOptimizer(
learning_rate=0.1, parameter_list=lenet.parameters(), momentum=0.9
optimizer = Momentum(
learning_rate=0.1, parameters=lenet.parameters(), momentum=0.9
)
train_reader = paddle.batch(
......@@ -168,18 +166,18 @@ class TestImperativeQatMatmul(unittest.TestCase):
.reshape(-1, 1)
)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
out = lenet(img)
acc = paddle.static.accuracy(out, label)
acc = paddle.metric.accuracy(out, label)
loss = paddle.nn.functional.cross_entropy(
out, label, reduction='none', use_softmax=False
)
avg_loss = paddle.mean(loss)
avg_loss.backward()
optimizer.minimize(avg_loss)
lenet.clear_gradients()
optimizer.step()
optimizer.clear_grad()
if batch_id % 100 == 0:
_logger.info(
......@@ -201,14 +199,14 @@ class TestImperativeQatMatmul(unittest.TestCase):
.astype('int64')
.reshape(-1, 1)
)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
out = lenet(img)
acc_top1 = paddle.static.accuracy(
acc_top1 = paddle.metric.accuracy(
input=out, label=label, k=1
)
acc_top5 = paddle.static.accuracy(
acc_top5 = paddle.metric.accuracy(
input=out, label=label, k=5
)
......
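The training-loop edits above also follow the 2.x dygraph idiom: `optimizer.step()` plus `optimizer.clear_grad()` replace `minimize()` and `clear_gradients()`. A self-contained sketch of one update step with a placeholder model:

import paddle

model = paddle.nn.Linear(4, 2)  # placeholder model
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.1, parameters=model.parameters(), momentum=0.9
)

x = paddle.rand([8, 4])
label = paddle.randint(0, 2, [8, 1])
out = model(x)
loss = paddle.nn.functional.cross_entropy(out, label)
avg_loss = paddle.mean(loss)
avg_loss.backward()
optimizer.step()        # replaces optimizer.minimize(avg_loss)
optimizer.clear_grad()  # replaces model.clear_gradients()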
......@@ -12,20 +12,19 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import logging
import os
import numpy as np
import random
import unittest
import logging
import numpy as np
import paddle
import paddle.nn as nn
from paddle.optimizer import Adam
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.framework import _test_eager_guard
from paddle.nn import Sequential
from paddle.nn import Linear
from paddle.nn.quant.quant_layers import QuantizedConv2DTranspose
from paddle.fluid.log_helper import get_logger
from paddle.optimizer import Adam
from paddle.quantization import ImperativeQuantAware
from paddle.static.log_helper import get_logger
os.environ["CPU_NUM"] = "1"
......@@ -110,7 +109,7 @@ class ModelForConv2dT(nn.Layer):
def __init__(self, num_classes=10):
super().__init__()
self.features = nn.Conv2DTranspose(4, 6, (3, 3))
self.fc = Linear(600, num_classes)
self.fc = nn.Linear(in_features=600, out_features=num_classes)
def forward(self, inputs):
x = self.features(inputs)
......@@ -123,28 +122,28 @@ class ImperativeLenet(paddle.nn.Layer):
def __init__(self, num_classes=10, classifier_activation='softmax'):
super().__init__()
self.features = Sequential(
paddle.nn.Conv2D(
nn.Conv2D(
in_channels=1,
out_channels=6,
kernel_size=3,
stride=1,
padding=1,
),
paddle.nn.MaxPool2D(kernel_size=2, stride=2),
paddle.nn.Conv2D(
nn.MaxPool2D(kernel_size=2, stride=2),
nn.Conv2D(
in_channels=6,
out_channels=16,
kernel_size=5,
stride=1,
padding=0,
),
paddle.nn.MaxPool2D(kernel_size=2, stride=2),
nn.MaxPool2D(kernel_size=2, stride=2),
)
self.fc = Sequential(
Linear(400, 120),
Linear(120, 84),
Linear(84, num_classes),
nn.Linear(in_features=400, out_features=120),
nn.Linear(in_features=120, out_features=84),
nn.Linear(in_features=84, out_features=num_classes),
)
def forward(self, inputs):
......@@ -160,7 +159,7 @@ class TestUserDefinedActPreprocess(unittest.TestCase):
_logger.info("test act_preprocess")
self.imperative_qat = ImperativeQuantAware(act_preprocess_layer=PACT)
def test_quant_aware_training(self):
def func_quant_aware_training(self):
imperative_qat = self.imperative_qat
seed = 1
np.random.seed(seed)
......@@ -170,8 +169,8 @@ class TestUserDefinedActPreprocess(unittest.TestCase):
fixed_state = {}
param_init_map = {}
for name, param in lenet.named_parameters():
p_shape = param.numpy().shape
p_value = param.numpy()
p_shape = np.array(param).shape
p_value = np.array(param)
if name.endswith("bias"):
value = np.zeros_like(p_value).astype('float32')
else:
......@@ -217,8 +216,8 @@ class TestUserDefinedActPreprocess(unittest.TestCase):
loss = nn.functional.loss.cross_entropy(out, label)
avg_loss = paddle.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
model.clear_gradients()
adam.step()
adam.clear_grad()
if batch_id % 50 == 0:
_logger.info(
"Train | At epoch {} step {}: loss = {:}, acc= {:}".format(
......@@ -262,6 +261,11 @@ class TestUserDefinedActPreprocess(unittest.TestCase):
train(lenet)
test(lenet)
def test_quant_aware_training(self):
with _test_eager_guard():
self.func_quant_aware_training()
self.func_quant_aware_training()
class TestUserDefinedWeightPreprocess(TestUserDefinedActPreprocess):
def setUp(self):
......
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
# copyright (c) 2022 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
......@@ -13,34 +13,25 @@
# limitations under the license.
import os
import numpy as np
import random
import unittest
import logging
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid import core
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn.layer import ReLU, LeakyReLU, Sigmoid, Softmax, ReLU6
from paddle.nn import Linear, Conv2D, Softmax, BatchNorm
from paddle.fluid.log_helper import get_logger
import numpy as np
from imperative_test_utils import (
ImperativeLenetWithSkipQuant,
fix_model_dict,
train_lenet,
ImperativeLenetWithSkipQuant,
)
import paddle
from paddle.framework import core, set_flags
from paddle.optimizer import Adam
from paddle.quantization import ImperativeQuantAware
INFER_MODEL_SUFFIX = ".pdmodel"
INFER_PARAMS_SUFFIX = ".pdiparams"
os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
set_flags({"FLAGS_cudnn_deterministic": True})
class TestImperativeOutSclae(unittest.TestCase):
......@@ -60,9 +51,7 @@ class TestImperativeOutSclae(unittest.TestCase):
lenet = fix_model_dict(lenet)
qat.quantize(lenet)
adam = AdamOptimizer(
learning_rate=lr, parameter_list=lenet.parameters()
)
adam = Adam(learning_rate=lr, parameters=lenet.parameters())
dynamic_loss_rec = []
lenet.train()
loss_list = train_lenet(lenet, reader, adam)
......@@ -88,14 +77,14 @@ class TestImperativeOutSclae(unittest.TestCase):
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe = paddle.static.Executor(place)
[
inference_program,
feed_target_names,
fetch_targets,
] = fluid.io.load_inference_model(
dirname=save_dir,
] = paddle.static.load_inference_model(
save_dir,
executor=exe,
model_filename="lenet" + INFER_MODEL_SUFFIX,
params_filename="lenet" + INFER_PARAMS_SUFFIX,
......
......@@ -13,12 +13,12 @@
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
import paddle.nn.quant.quant_layers as quant_layers
from paddle.framework import core
paddle.enable_static()
......@@ -38,23 +38,23 @@ def init_data(batch_size=32, img_shape=[784], label_range=9):
class TestMovingAverageAbsMaxScaleOp(unittest.TestCase):
def check_backward(self, use_cuda):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
image = fluid.layers.data(
name='image', shape=[784], dtype='float32'
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
image = paddle.static.data(
name='image', shape=[-1, 784], dtype='float32'
)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
fc_tmp = fluid.layers.fc(image, size=10, act='softmax')
label = paddle.static.data(
name='label', shape=[-1, 1], dtype='int64'
)
fc_tmp = paddle.static.nn.fc(image, size=10, activation='softmax')
out_scale = quant_layers.MovingAverageAbsMaxScale(
name=fc_tmp.name, dtype=fc_tmp.dtype
)
fc_tmp_1 = out_scale(fc_tmp)
cross_entropy = paddle.nn.functional.softmax_with_cross_entropy(
fc_tmp, label
)
cross_entropy = paddle.nn.functional.cross_entropy(fc_tmp, label)
loss = paddle.mean(cross_entropy)
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd = paddle.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss)
moving_average_abs_max_scale_ops = [
......@@ -66,13 +66,13 @@ class TestMovingAverageAbsMaxScaleOp(unittest.TestCase):
len(moving_average_abs_max_scale_ops) == 1
), "The number of moving_average_abs_max_scale_ops should be 1."
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_program)
binary = fluid.compiler.CompiledProgram(
main_program
).with_data_parallel(loss_name=loss.name)
binary = paddle.static.CompiledProgram(main_program).with_data_parallel(
loss_name=loss.name
)
img, label = init_data()
feed_dict = {"image": img, "label": label}
......
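The static-graph rewrite above shows the usual data/fc migration: `paddle.static.data` keeps an explicit -1 batch dimension, and `paddle.static.nn.fc` takes `activation=` instead of `act=`. A condensed sketch of the same network definition:

import paddle

paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    image = paddle.static.data(name='image', shape=[-1, 784], dtype='float32')
    label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')
    fc_tmp = paddle.static.nn.fc(image, size=10, activation='softmax')
    loss = paddle.mean(paddle.nn.functional.cross_entropy(fc_tmp, label))
    sgd = paddle.optimizer.SGD(learning_rate=1e-3)
    sgd.minimize(loss)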
......@@ -11,21 +11,20 @@
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import os
import time
import sys
import random
import math
import functools
import contextlib
import struct
import sys
import tempfile
import time
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.dataset.common import download
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.static.quantization import PostTrainingQuantization
paddle.enable_static()
......@@ -133,15 +132,27 @@ class TestPostTrainingQuantization(unittest.TestCase):
return reader
def run_program(self, model_path, data_path, infer_iterations):
def run_program(
self,
model_path,
model_filename,
params_filename,
data_path,
infer_iterations,
):
print("test model path:" + model_path)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
[
infer_program,
feed_dict,
fetch_targets,
] = fluid.io.load_inference_model(model_path, exe)
] = paddle.static.load_inference_model(
model_path,
exe,
model_filename=model_filename,
params_filename=params_filename,
)
val_reader = self.get_simple_reader(data_path, place)
......@@ -176,6 +187,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
def generate_quantized_model(
self,
model_path,
model_filename,
params_filename,
data_path,
algo="KL",
round_type="round",
......@@ -188,14 +201,16 @@ class TestPostTrainingQuantization(unittest.TestCase):
onnx_format=False,
):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.global_scope()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
scope = paddle.static.global_scope()
batch_generator = self.get_batch_reader(data_path, place)
ptq = PostTrainingQuantization(
executor=exe,
model_dir=model_path,
model_filename=model_filename,
params_filename=params_filename,
batch_generator=batch_generator,
batch_nums=batch_nums,
algo=algo,
......@@ -214,6 +229,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
def run_test(
self,
model_name,
model_filename,
params_filename,
model_url,
model_md5,
data_name,
......@@ -242,7 +259,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
)
)
(fp32_latency, fp32_acc) = self.run_program(
fp32_model_path, data_path, infer_iterations
fp32_model_path,
model_filename,
params_filename,
data_path,
infer_iterations,
)
print(
......@@ -252,6 +273,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
)
self.generate_quantized_model(
fp32_model_path,
model_filename,
params_filename,
data_path,
algo,
round_type,
......@@ -270,7 +293,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
)
)
(int8_latency, int8_acc) = self.run_program(
self.int8_model_path, data_path, infer_iterations
self.int8_model_path,
'model.pdmodel',
'model.pdiparams',
data_path,
infer_iterations,
)
print("---Post training quantization of {} method---".format(algo))
......@@ -293,8 +320,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
def test_post_training_avg(self):
model_name = "nlp_lstm_fp32_model"
model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz"
model_md5 = "519b8eeac756e7b4b7bcb2868e880452"
model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model_combined.tar.gz"
model_md5 = "5b47cd7ba2afcf24120d9727ed3f05a7"
data_name = "quant_lstm_input_data"
data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
data_md5 = "add84c754e9b792fea1fbd728d134ab7"
......@@ -309,6 +336,8 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
quant_iterations = 10
self.run_test(
model_name,
'model.pdmodel',
'model.pdiparams',
model_url,
model_md5,
data_name,
......@@ -329,8 +358,8 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
def not_test_post_training_avg_onnx_format(self):
model_name = "nlp_lstm_fp32_model"
model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz"
model_md5 = "519b8eeac756e7b4b7bcb2868e880452"
model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model_combined.tar.gz"
model_md5 = "5b47cd7ba2afcf24120d9727ed3f05a7"
data_name = "quant_lstm_input_data"
data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
data_md5 = "add84c754e9b792fea1fbd728d134ab7"
......@@ -346,6 +375,8 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
onnx_format = True
self.run_test(
model_name,
'model.pdmodel',
'model.pdiparams',
model_url,
model_md5,
data_name,
......
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
# copyright (c) 2022 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
......@@ -11,20 +11,18 @@
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import os
import time
import sys
import random
import math
import functools
import sys
import tempfile
import contextlib
import time
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.dataset.common import download
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.dataset.common import md5file
from paddle.static.quantization import PostTrainingQuantization
paddle.enable_static()
......@@ -38,12 +36,13 @@ class TestPostTrainingQuantization(unittest.TestCase):
self.int8_model_path = os.path.join(
self.root_path.name, "post_training_quantization"
)
self.download_path = 'int8/download'
self.cache_folder = os.path.expanduser(
'~/.cache/paddle/dataset/' + self.download_path
self.download_path = f'download_model_{time.time()}'
self.cache_folder = os.path.join(
self.root_path.name, self.download_path
)
try:
os.system("mkdir -p " + self.int8_model_path)
os.system("mkdir -p " + self.cache_folder)
except Exception as e:
print(
"Failed to create {} due to {}".format(
......@@ -62,25 +61,110 @@ class TestPostTrainingQuantization(unittest.TestCase):
)
os.system(cmd)
def download(self, url, dirname, md5sum, save_name=None):
import shutil
import requests
filename = os.path.join(
dirname, url.split('/')[-1] if save_name is None else save_name
)
if os.path.exists(filename) and md5file(filename) == md5sum:
return filename
retry = 0
retry_limit = 3
while not (os.path.exists(filename) and md5file(filename) == md5sum):
if os.path.exists(filename):
                sys.stderr.write(
                    "file %s md5 %s does not match expected md5 %s\n"
                    % (filename, md5file(filename), md5sum)
                )

if retry < retry_limit:
retry += 1
else:
raise RuntimeError(
"Cannot download {0} within retry limit {1}".format(
url, retry_limit
)
)
sys.stderr.write(
"Cache file %s not found, downloading %s \n" % (filename, url)
)
sys.stderr.write("Begin to download\n")
try:
r = requests.get(url, stream=True)
total_length = r.headers.get('content-length')
if total_length is None:
with open(filename, 'wb') as f:
shutil.copyfileobj(r.raw, f)
else:
with open(filename, 'wb') as f:
chunk_size = 4096
total_length = int(total_length)
total_iter = total_length / chunk_size + 1
log_interval = (
total_iter // 20 if total_iter > 20 else 1
)
log_index = 0
bar = paddle.hapi.progressbar.ProgressBar(
total_iter, name='item'
)
for data in r.iter_content(chunk_size=chunk_size):
f.write(data)
log_index += 1
bar.update(log_index, {})
if log_index % log_interval == 0:
bar.update(log_index)
            except Exception:
                # download failed (e.g. a network error); retry within the limit
                continue
sys.stderr.write("\nDownload finished\n")
sys.stdout.flush()
return filename
def download_model(self, data_url, data_md5, folder_name):
download(data_url, self.download_path, data_md5)
self.download(data_url, self.cache_folder, data_md5)
os.system(f'wget -q {data_url}')
file_name = data_url.split('/')[-1]
zip_path = os.path.join(self.cache_folder, file_name)
print('Data is downloaded at {0}'.format(zip_path))
print(
'Data is downloaded at {0}. File exists: {1}'.format(
zip_path, os.path.exists(zip_path)
)
)
data_cache_folder = os.path.join(self.cache_folder, folder_name)
self.cache_unzipping(data_cache_folder, zip_path)
return data_cache_folder
def run_program(self, model_path, batch_size, infer_iterations):
print("test model path:" + model_path)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
def run_program(
self,
model_path,
model_filename,
params_filename,
batch_size,
infer_iterations,
):
print(
"test model path: {}. File exists: {}".format(
model_path, os.path.exists(model_path)
)
)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
[
infer_program,
feed_dict,
fetch_targets,
] = fluid.io.load_inference_model(model_path, exe)
] = paddle.static.load_inference_model(
model_path,
exe,
model_filename=model_filename,
params_filename=params_filename,
)
val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size)
img_shape = [1, 28, 28]
......@@ -119,6 +203,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
def generate_quantized_model(
self,
model_path,
model_filename,
params_filename,
algo="KL",
round_type="round",
quantizable_op_type=["conv2d"],
......@@ -132,13 +218,15 @@ class TestPostTrainingQuantization(unittest.TestCase):
bias_correction=False,
):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
val_reader = paddle.dataset.mnist.train()
ptq = PostTrainingQuantization(
executor=exe,
model_dir=model_path,
model_filename=model_filename,
params_filename=params_filename,
sample_generator=val_reader,
batch_size=batch_size,
batch_nums=batch_nums,
......@@ -158,6 +246,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
def run_test(
self,
model_name,
model_filename,
params_filename,
data_url,
data_md5,
algo,
......@@ -183,8 +273,13 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_name, infer_iterations * batch_size
)
)
(fp32_throughput, fp32_latency, fp32_acc1) = self.run_program(
origin_model_path, batch_size, infer_iterations
origin_model_path,
model_filename,
params_filename,
batch_size,
infer_iterations,
)
print(
......@@ -194,6 +289,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
)
self.generate_quantized_model(
origin_model_path,
model_filename,
params_filename,
algo,
round_type,
quantizable_op_type,
......@@ -213,7 +310,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
)
)
(int8_throughput, int8_latency, int8_acc1) = self.run_program(
self.int8_model_path, batch_size, infer_iterations
self.int8_model_path,
'model.pdmodel',
'model.pdiparams',
batch_size,
infer_iterations,
)
print("---Post training quantization of {} method---".format(algo))
......@@ -236,10 +337,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
def test_post_training_kl(self):
model_name = "mnist_model"
data_url = (
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
)
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model_combined.tar.gz"
data_md5 = "a49251d3f555695473941e5a725c6014"
algo = "KL"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
......@@ -252,6 +351,8 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
quant_iterations = 5
self.run_test(
model_name,
'model.pdmodel',
'model.pdiparams',
data_url,
data_md5,
algo,
......@@ -270,10 +371,8 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
class TestPostTraininghistForMnist(TestPostTrainingQuantization):
def test_post_training_hist(self):
model_name = "mnist_model"
data_url = (
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
)
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model_combined.tar.gz"
data_md5 = "a49251d3f555695473941e5a725c6014"
algo = "hist"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
......@@ -286,6 +385,8 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
quant_iterations = 5
self.run_test(
model_name,
'model.pdmodel',
'model.pdiparams',
data_url,
data_md5,
algo,
......@@ -304,10 +405,8 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
def test_post_training_mse(self):
model_name = "mnist_model"
data_url = (
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
)
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model_combined.tar.gz"
data_md5 = "a49251d3f555695473941e5a725c6014"
algo = "mse"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
......@@ -320,6 +419,8 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
quant_iterations = 5
self.run_test(
model_name,
'model.pdmodel',
'model.pdiparams',
data_url,
data_md5,
algo,
......@@ -338,10 +439,8 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
def test_post_training_mse(self):
model_name = "mnist_model"
data_url = (
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
)
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model_combined.tar.gz"
data_md5 = "a49251d3f555695473941e5a725c6014"
algo = "emd"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
......@@ -354,6 +453,8 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
quant_iterations = 5
self.run_test(
model_name,
'model.pdmodel',
'model.pdiparams',
data_url,
data_md5,
algo,
......@@ -372,10 +473,8 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
def test_post_training_avg(self):
model_name = "mnist_model"
data_url = (
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
)
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model_combined.tar.gz"
data_md5 = "a49251d3f555695473941e5a725c6014"
algo = "avg"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
......@@ -388,6 +487,8 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
quant_iterations = 5
self.run_test(
model_name,
'model.pdmodel',
'model.pdiparams',
data_url,
data_md5,
algo,
......@@ -406,10 +507,8 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
def test_post_training_abs_max(self):
model_name = "mnist_model"
data_url = (
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
)
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model_combined.tar.gz"
data_md5 = "a49251d3f555695473941e5a725c6014"
algo = "abs_max"
round_type = "round"
quantizable_op_type = ["conv2d", "mul"]
......@@ -422,6 +521,8 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
quant_iterations = 10
self.run_test(
model_name,
'model.pdmodel',
'model.pdiparams',
data_url,
data_md5,
algo,
......@@ -440,10 +541,8 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
def test_post_training_mse(self):
model_name = "mnist_model"
data_url = (
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
)
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model_combined.tar.gz"
data_md5 = "a49251d3f555695473941e5a725c6014"
algo = "mse"
round_type = "adaround"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
......@@ -457,6 +556,8 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
bias_correction = True
self.run_test(
model_name,
'model.pdmodel',
'model.pdiparams',
data_url,
data_md5,
algo,
......@@ -476,10 +577,8 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
def test_post_training_kl(self):
model_name = "mnist_model"
data_url = (
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
)
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model_combined.tar.gz"
data_md5 = "a49251d3f555695473941e5a725c6014"
algo = "KL"
round_type = "adaround"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
......@@ -492,6 +591,8 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
quant_iterations = 5
self.run_test(
model_name,
'model.pdmodel',
'model.pdiparams',
data_url,
data_md5,
algo,
......@@ -510,10 +611,8 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
def test_post_training_mse_onnx_format(self):
model_name = "mnist_model"
data_url = (
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
)
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model_combined.tar.gz"
data_md5 = "a49251d3f555695473941e5a725c6014"
algo = "mse"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
......@@ -527,6 +626,8 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
quant_iterations = 5
self.run_test(
model_name,
'model.pdmodel',
'model.pdiparams',
data_url,
data_md5,
algo,
......@@ -548,10 +649,8 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
):
def test_post_training_mse_onnx_format_full_quant(self):
model_name = "mnist_model"
data_url = (
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
)
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model_combined.tar.gz"
data_md5 = "a49251d3f555695473941e5a725c6014"
algo = "mse"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
......@@ -565,6 +664,8 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
quant_iterations = 5
self.run_test(
model_name,
'model.pdmodel',
'model.pdiparams',
data_url,
data_md5,
algo,
......@@ -584,10 +685,8 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
def test_post_training_avg_skip_op(self):
model_name = "mnist_model"
data_url = (
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
)
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model_combined.tar.gz"
data_md5 = "a49251d3f555695473941e5a725c6014"
algo = "avg"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
......@@ -601,6 +700,8 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
skip_tensor_list = ["fc_0.w_0"]
self.run_test(
model_name,
'model.pdmodel',
'model.pdiparams',
data_url,
data_md5,
algo,
......
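A hedged sketch of the static post-training quantization driver the MNIST tests above configure. The model directories are placeholders, and the keyword names on the save call are assumed to mirror the loader used later in run_test.

import paddle
from paddle.static.quantization import PostTrainingQuantization

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="/path/to/mnist_fp32_model",   # placeholder directory
    model_filename="model.pdmodel",
    params_filename="model.pdiparams",
    sample_generator=paddle.dataset.mnist.train(),
    batch_size=10,
    batch_nums=5,
    algo="KL",
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
)
ptq.quantize()
ptq.save_quantized_model(
    "/path/to/int8_model",                   # placeholder output directory
    model_filename="model.pdmodel",
    params_filename="model.pdiparams",
)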
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
# copyright (c) 2022 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
......@@ -11,21 +11,20 @@
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import functools
import os
import time
import sys
import random
import math
import functools
import contextlib
import sys
import tempfile
import time
import unittest
import numpy as np
from PIL import Image, ImageEnhance
from PIL import Image
import paddle
import paddle.fluid as fluid
from paddle.dataset.common import download
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.static.quantization import PostTrainingQuantization
paddle.enable_static()
......@@ -52,7 +51,7 @@ def resize_short(img, target_size):
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center == True:
if center is True:
w_start = (width - size) / 2
h_start = (height - size) / 2
else:
......@@ -201,19 +200,26 @@ class TestPostTrainingQuantization(unittest.TestCase):
def download_model(self):
pass
def run_program(self, model_path, batch_size, infer_iterations):
def run_program(
self,
model_path,
model_filename,
params_filename,
batch_size,
infer_iterations,
):
image_shape = [3, 224, 224]
place = fluid.CPUPlace()
exe = fluid.Executor(place)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
[
infer_program,
feed_dict,
fetch_targets,
] = fluid.io.load_inference_model(
] = paddle.static.load_inference_model(
model_path,
exe,
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
model_filename=model_filename,
params_filename=params_filename,
)
val_reader = paddle.batch(val(), batch_size)
iterations = infer_iterations
......@@ -260,6 +266,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
def generate_quantized_model(
self,
model_path,
model_filename,
params_filename,
quantizable_op_type,
batch_size,
algo="KL",
......@@ -278,17 +286,16 @@ class TestPostTrainingQuantization(unittest.TestCase):
)
sys.exit(-1)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.global_scope()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
val_reader = val()
ptq = PostTrainingQuantization(
executor=exe,
sample_generator=val_reader,
model_dir=model_path,
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
model_filename=model_filename,
params_filename=params_filename,
batch_size=batch_size,
batch_nums=batch_nums,
algo=algo,
......@@ -309,6 +316,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
def run_test(
self,
model,
model_filename,
params_filename,
algo,
round_type,
data_urls,
......@@ -333,17 +342,16 @@ class TestPostTrainingQuantization(unittest.TestCase):
)
(fp32_throughput, fp32_latency, fp32_acc1) = self.run_program(
os.path.join(model_cache_folder, "MobileNetV1_infer"),
model_filename,
params_filename,
batch_size,
infer_iterations,
)
print(
"Start INT8 post training quantization for {0} on {1} images ...".format(
model, batch_nums * batch_size
)
)
self.generate_quantized_model(
os.path.join(model_cache_folder, "MobileNetV1_infer"),
model_filename,
params_filename,
quantizable_op_type,
batch_size,
algo,
......@@ -361,7 +369,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
)
)
(int8_throughput, int8_latency, int8_acc1) = self.run_program(
self.int8_model, batch_size, infer_iterations
self.int8_model,
model_filename,
params_filename,
batch_size,
infer_iterations,
)
print("---Post training quantization of {} method---".format(algo))
......@@ -403,6 +415,8 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
batch_nums = 3
self.run_test(
model,
'inference.pdmodel',
'inference.pdiparams',
algo,
round_type,
data_urls,
......@@ -435,6 +449,8 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
diff_threshold = 0.025
self.run_test(
model,
'inference.pdmodel',
'inference.pdiparams',
algo,
round_type,
data_urls,
......@@ -468,6 +484,8 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
batch_nums = 3
self.run_test(
model,
'inference.pdmodel',
'inference.pdiparams',
algo,
round_type,
data_urls,
......@@ -501,6 +519,8 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
diff_threshold = 0.05
self.run_test(
model,
'inference.pdmodel',
'inference.pdiparams',
algo,
round_type,
data_urls,
......@@ -535,6 +555,8 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
batch_nums = 3
self.run_test(
model,
'inference.pdmodel',
'inference.pdiparams',
algo,
round_type,
data_urls,
......
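The hunks above replace the fluid entry points (fluid.Executor, fluid.io.load_inference_model, paddle.fluid.contrib.slim.quantization) with their paddle.static counterparts and pass the combined-format file names explicitly. A minimal sketch of the resulting post-training quantization flow, using only calls that appear in the updated tests; the directory names and the random calibration reader are illustrative placeholders, not part of the change:

```python
import numpy as np
import paddle
from paddle.static.quantization import PostTrainingQuantization

paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)

def calib_reader():
    # Hypothetical calibration samples: random images plus a dummy label.
    for _ in range(32):
        yield np.random.random([3, 224, 224]).astype("float32"), 0

# Calibrate and quantize a combined-format FP32 inference model (sketch).
ptq = PostTrainingQuantization(
    executor=exe,
    sample_generator=calib_reader,
    model_dir="MobileNetV1_infer",          # hypothetical FP32 model directory
    model_filename="inference.pdmodel",
    params_filename="inference.pdiparams",
    batch_size=10,
    batch_nums=3,
    algo="KL",
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
)
ptq.quantize()
ptq.save_quantized_model(
    "MobileNetV1_int8",                     # hypothetical output directory
    model_filename="inference.pdmodel",
    params_filename="inference.pdiparams",
)

# The quantized model is then reloaded through the same loader as the FP32 one.
[program, feed_names, fetch_targets] = paddle.static.load_inference_model(
    "MobileNetV1_int8",
    exe,
    model_filename="inference.pdmodel",
    params_filename="inference.pdiparams",
)
```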
......@@ -12,24 +12,22 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import functools
import os
import random
import sys
import time
import paddle
import random
import unittest
import functools
import contextlib
import numpy as np
import paddle.fluid as fluid
from PIL import Image, ImageEnhance
from paddle.fluid.contrib.slim.quantization import (
PostTrainingQuantizationProgram,
)
from PIL import Image
from test_post_training_quantization_mobilenetv1 import (
TestPostTrainingQuantization,
)
import paddle
from paddle.static.quantization import PostTrainingQuantizationProgram
paddle.enable_static()
random.seed(0)
......@@ -55,7 +53,7 @@ def resize_short(img, target_size):
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center == True:
if center is True:
w_start = (width - size) / 2
h_start = (height - size) / 2
else:
......@@ -115,15 +113,27 @@ def val(data_dir=DATA_DIR):
class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization):
def run_program(self, model_path, batch_size, infer_iterations):
def run_program(
self,
model_path,
model_filename,
params_filename,
batch_size,
infer_iterations,
):
image_shape = [3, 224, 224]
place = fluid.CPUPlace()
exe = fluid.Executor(place)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
[
infer_program,
feed_dict,
fetch_targets,
] = fluid.io.load_inference_model(model_path, exe)
] = paddle.static.load_inference_model(
model_path,
exe,
model_filename=model_filename,
params_filename=params_filename,
)
val_reader = paddle.batch(val(), batch_size)
iterations = infer_iterations
test_info = []
......@@ -162,7 +172,12 @@ class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization):
infer_program,
feed_dict,
fetch_targets,
] = fluid.io.load_inference_model(model_path, exe)
] = paddle.static.load_inference_model(
model_path,
exe,
model_filename=model_filename,
params_filename=params_filename,
)
return (
throughput,
latency,
......@@ -193,9 +208,8 @@ class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization):
)
sys.exit(-1)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.global_scope()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
val_reader = val()
same_scale_tensor_list = [
['batch_norm_3.tmp_2#/#1', 'batch_norm_4.tmp_2#*#1'],
......@@ -231,6 +245,8 @@ class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization):
def run_test(
self,
model,
model_filename,
params_filename,
algo,
round_type,
data_urls,
......@@ -244,7 +260,6 @@ class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization):
):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
model_cache_folder = self.download_data(data_urls, data_md5s, model)
......@@ -262,14 +277,12 @@ class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization):
fetch_targets,
) = self.run_program(
os.path.join(model_cache_folder, "model"),
model_filename,
params_filename,
batch_size,
infer_iterations,
)
print(
"Start INT8 post training quantization for {0} on {1} images ...".format(
model, sample_iterations * batch_size
)
)
self.generate_quantized_model(
infer_program,
quantizable_op_type,
......@@ -289,7 +302,11 @@ class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization):
)
)
(int8_throughput, int8_latency, int8_acc1, _, _, _) = self.run_program(
self.int8_model, batch_size, infer_iterations
self.int8_model,
model_filename,
params_filename,
batch_size,
infer_iterations,
)
print("---Post training quantization of {} method---".format(algo))
......@@ -317,9 +334,9 @@ class TestPostTrainingProgramAbsMaxForResnet50(
algo = "abs_max"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model_combined.tar.gz'
]
data_md5s = ['4a5194524823d9b76da6e738e1367881']
data_md5s = ['db212fd4e9edc83381aef4533107e60c']
quantizable_op_type = ["conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -327,6 +344,8 @@ class TestPostTrainingProgramAbsMaxForResnet50(
diff_threshold = 0.025
self.run_test(
model,
'model.pdmodel',
'model.pdiparams',
algo,
round_type,
data_urls,
......
......@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
from test_post_training_quantization_mobilenetv1 import (
TestPostTrainingQuantization,
)
import paddle
paddle.enable_static()
......@@ -28,9 +29,9 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
algo = "min_max"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model_combined.tar.gz'
]
data_md5s = ['4a5194524823d9b76da6e738e1367881']
data_md5s = ['db212fd4e9edc83381aef4533107e60c']
quantizable_op_type = ["conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -38,6 +39,8 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
diff_threshold = 0.025
self.run_test(
model,
'model.pdmodel',
'model.pdiparams',
algo,
round_type,
data_urls,
......@@ -56,9 +59,9 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
algo = "min_max"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model_combined.tar.gz'
]
data_md5s = ['4a5194524823d9b76da6e738e1367881']
data_md5s = ['db212fd4e9edc83381aef4533107e60c']
quantizable_op_type = ["conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -67,6 +70,8 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
onnx_format = True
self.run_test(
model,
'model.pdmodel',
'model.pdiparams',
algo,
round_type,
data_urls,
......
......@@ -11,19 +11,17 @@
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import os
import time
import sys
import random
import math
import functools
import contextlib
import sys
import time
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.dataset.common import download
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.static.quantization import PostTrainingQuantization
paddle.enable_static()
......@@ -77,13 +75,13 @@ class TestPostTrainingQuantization(unittest.TestCase):
def run_program(self, model_path, batch_size, infer_iterations):
print("test model path:" + model_path)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
[
infer_program,
feed_dict,
fetch_targets,
] = fluid.io.load_inference_model(
] = paddle.static.load_inference_model(
model_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams',
......@@ -137,9 +135,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_data_loader=False,
):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.global_scope()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
val_reader = paddle.dataset.mnist.train()
def val_data_generator():
......
......@@ -13,12 +13,13 @@
# limitations under the license.
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass
import paddle
from paddle.fluid.framework import IrGraph
from paddle.framework import core
from paddle.static.quantization import Quant2Int8MkldnnPass
paddle.enable_static()
......@@ -28,8 +29,8 @@ class TestQuant2Int8MkldnnPassMul(unittest.TestCase):
return "mul"
def setUp(self):
self.scope = fluid.Scope()
self.place = fluid.CPUPlace()
self.scope = paddle.static.global_scope()
self.place = paddle.CPUPlace()
self.dtype = np.float32
self.use_mkldnn = True
......@@ -67,8 +68,8 @@ class TestQuant2Int8MkldnnPassMul(unittest.TestCase):
)
def test_dequantize_op_weights(self):
program = fluid.Program()
with fluid.program_guard(program):
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.prepare_program_mul(program)
graph = IrGraph(core.Graph(program.desc), for_test=True)
......@@ -131,8 +132,8 @@ class TestQuant2Int8MkldnnPassMatmulV2(TestQuant2Int8MkldnnPassMul):
class TestQuant2Int8MkldnnPassConv2D(unittest.TestCase):
def setUp(self):
self.scope = fluid.Scope()
self.place = fluid.CPUPlace()
self.scope = paddle.static.global_scope()
self.place = paddle.CPUPlace()
self.dtype = np.float32
self.use_cudnn = False
self.use_mkldnn = True
......@@ -218,8 +219,8 @@ class TestQuant2Int8MkldnnPassConv2D(unittest.TestCase):
self.assertTrue(op.op().attr("fuse_activation") == "relu")
def test_quant_update_activation(self):
program = fluid.Program()
with fluid.program_guard(program):
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.prepare_program_conv2d(program)
graph = IrGraph(core.Graph(program.desc), for_test=True)
graph = self.remove_fuse_activation_attribute(graph)
......@@ -239,8 +240,8 @@ class TestQuant2Int8MkldnnPassConv2D(unittest.TestCase):
return "nearest_interp"
def setUp(self):
self.scope = fluid.Scope()
self.place = fluid.CPUPlace()
self.scope = paddle.static.global_scope()
self.place = paddle.CPUPlace()
self.dtype = np.float32
self.use_cudnn = False
self.use_mkldnn = True
......@@ -352,8 +353,8 @@ class TestQuant2Int8MkldnnPassConv2D(unittest.TestCase):
self.assertTrue(op.op().attr("mkldnn_data_type") == "int8")
def test_quant_update_activation(self):
program = fluid.Program()
with fluid.program_guard(program):
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.prepare_program(program)
graph = IrGraph(core.Graph(program.desc), for_test=True)
quant2_int8_mkldnn_pass = Quant2Int8MkldnnPass(
......
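These MKL-DNN pass tests now build their programs and scopes through paddle.static, while IrGraph still comes from paddle.fluid.framework and core from paddle.framework. A short sketch of the Program-to-IrGraph round trip the updated tests rely on; the tiny fc network stands in for the conv/mul preparation helpers, and the pass invocation is left as a placeholder:

```python
import paddle
from paddle.fluid.framework import IrGraph
from paddle.framework import core

paddle.enable_static()

program = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(program, startup):
    x = paddle.static.data(name="x", shape=[-1, 16], dtype="float32")
    out = paddle.static.nn.fc(x, size=8)

# The MKL-DNN quant passes operate on an IrGraph wrapped around the Program desc.
graph = IrGraph(core.Graph(program.desc), for_test=True)
# ... apply Quant2Int8MkldnnPass(...) to `graph` here ...
rebuilt_program = graph.to_program()  # convert back once the pass has run
```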
......@@ -13,40 +13,46 @@
# limitations under the license.
import os
import unittest
import random
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantInt8MkldnnPass
from paddle.fluid import core
from paddle.framework import core
from paddle.static.quantization import (
QuantInt8MkldnnPass,
QuantizationFreezePass,
QuantizationTransformPass,
)
paddle.enable_static()
os.environ["CPU_NUM"] = "1"
def conv_net(img, label):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
conv_out_1 = paddle.static.nn.conv2d(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
act="relu",
act='relu',
)
conv_pool_1 = paddle.nn.functional.max_pool2d(
conv_out_1, kernel_size=2, stride=2
)
conv_pool_1 = paddle.static.nn.batch_norm(conv_pool_1)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
conv_out_2 = paddle.static.nn.conv2d(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
act="relu",
num_filters=20,
act='relu',
)
prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
conv_pool_2 = paddle.nn.functional.max_pool2d(
conv_out_2, kernel_size=2, stride=2
)
prediction = paddle.static.nn.fc(conv_pool_2, size=10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(
input=prediction, label=label, reduction='none', use_softmax=False
)
......@@ -77,17 +83,17 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
def build_program(self, main, startup, is_test, seed):
main.random_seed = seed
startup.random_seed = seed
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
img = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32'
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main, startup):
img = paddle.static.data(
name='image', shape=[-1, 1, 28, 28], dtype='float32'
)
label = fluid.layers.data(
name='label', shape=[1], dtype='int64'
label = paddle.static.data(
name='label', shape=[-1, 1], dtype='int64'
)
loss = conv_net(img, label)
if not is_test:
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt = paddle.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
return [img, label], loss
......@@ -103,19 +109,19 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
random.seed(0)
np.random.seed(0)
main = fluid.Program()
startup = fluid.Program()
test_program = fluid.Program()
main = paddle.static.Program()
startup = paddle.static.Program()
test_program = paddle.static.Program()
feeds, loss = self.build_program(main, startup, False, seed)
self.build_program(test_program, startup, True, seed)
test_program = test_program.clone(for_test=True)
main_graph = IrGraph(core.Graph(main.desc), for_test=False)
test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
with fluid.scope_guard(scope):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
scope = paddle.static.global_scope()
with paddle.static.scope_guard(scope):
exe.run(startup)
# Apply the QuantizationTransformPass
transform_pass = QuantizationTransformPass(
......@@ -133,12 +139,12 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
)
transform_pass.apply(test_graph)
build_strategy = fluid.BuildStrategy()
build_strategy = paddle.static.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy
)
binary = paddle.static.CompiledProgram(
main_graph.graph
).with_data_parallel(loss_name=loss.name, build_strategy=build_strategy)
quantized_test_program = test_graph.to_program()
iters = 5
batch_size = 8
......@@ -150,10 +156,10 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size
)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
feeder = paddle.fluid.DataFeeder(feed_list=feeds, place=place)
# Training the model to get the weights value
with fluid.scope_guard(scope):
with paddle.static.scope_guard(scope):
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(
......@@ -204,12 +210,12 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
+ activation_quant_type
+ '_'
+ weight_quant_type,
np.sum(w_mkldnn),
np.sum(mul_w_mkldnn),
)
)
def test_mkldnn_graph_cpu_static(self):
with fluid.unique_name.guard():
with paddle.utils.unique_name.guard():
self.mkldnn_based_freeze_graph(
False,
seed=2,
......
......@@ -13,19 +13,23 @@
# limitations under the license.
import os
import unittest
import random
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPassV2
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core
from paddle.framework import core
from paddle.static.quantization import (
AddQuantDequantPass,
ConvertToInt8Pass,
QuantizationFreezePass,
QuantizationTransformPass,
QuantizationTransformPassV2,
TransformForMobilePass,
)
paddle.enable_static()
......@@ -34,11 +38,13 @@ os.environ["CPU_NUM"] = "1"
def linear_fc(num):
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
data = paddle.static.data(
name='image', shape=[-1, 1, 32, 32], dtype='float32'
)
label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')
hidden = data
for _ in range(num):
hidden = fluid.layers.fc(hidden, size=128, act='relu')
hidden = paddle.static.nn.fc(hidden, size=128, activation='relu')
loss = paddle.nn.functional.cross_entropy(
input=hidden, label=label, reduction='none', use_softmax=False
)
......@@ -61,34 +67,30 @@ def residual_block(num, quant_skip_pattern=None):
)
return paddle.static.nn.batch_norm(input=tmp, act=act)
data = fluid.layers.data(
data = paddle.static.data(
name='image',
shape=[1, 1, 32, 32],
dtype='float32',
append_batch_size=False,
)
label = fluid.layers.data(
name='label', shape=[1, 1], dtype='int64', append_batch_size=False
)
label = paddle.static.data(name='label', shape=[1, 1], dtype='int64')
hidden = data
for _ in range(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = paddle.nn.functional.relu(paddle.add(x=conv, y=short))
matmul_weight = paddle.create_parameter(
hidden = paddle.add(x=conv, y=short)
hidden = paddle.nn.functional.relu(hidden)
matmul_weight = paddle.static.create_parameter(
shape=[1, 16, 32, 32], dtype='float32'
)
hidden = paddle.matmul(hidden, matmul_weight, True, True)
if quant_skip_pattern:
with fluid.name_scope(quant_skip_pattern):
with paddle.static.name_scope(quant_skip_pattern):
pool = paddle.nn.functional.avg_pool2d(
x=hidden, kernel_size=2, stride=2
hidden, kernel_size=2, stride=2
)
else:
pool = paddle.nn.functional.avg_pool2d(
x=hidden, kernel_size=2, stride=2
)
fc = fluid.layers.fc(input=pool, size=10)
pool = paddle.nn.functional.avg_pool2d(hidden, kernel_size=2, stride=2)
fc = paddle.static.nn.fc(pool, size=10)
loss = paddle.nn.functional.cross_entropy(
input=fc, label=label, reduction='none', use_softmax=False
)
......@@ -97,28 +99,29 @@ def residual_block(num, quant_skip_pattern=None):
def conv_net(img, label, quant_skip_pattern):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
conv_out_1 = paddle.static.nn.conv2d(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
pool_type='max',
act="relu",
act='relu',
)
conv_pool_1 = paddle.nn.functional.max_pool2d(
conv_out_1, kernel_size=2, stride=2
)
conv_pool_1 = paddle.static.nn.batch_norm(conv_pool_1)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
conv_out_2 = paddle.static.nn.conv2d(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
pool_type='avg',
act="relu",
num_filters=20,
act='relu',
)
conv_pool_2 = paddle.nn.functional.avg_pool2d(
conv_out_2, kernel_size=2, stride=2
)
hidden = fluid.layers.fc(input=conv_pool_2, size=100, act='relu')
with fluid.name_scope(quant_skip_pattern):
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
hidden = paddle.static.nn.fc(conv_pool_2, size=100, activation='relu')
with paddle.static.name_scope(quant_skip_pattern):
prediction = paddle.static.nn.fc(hidden, size=10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(
input=prediction, label=label, reduction='none', use_softmax=False
)
......@@ -164,16 +167,16 @@ class TestQuantizationTransformPass(unittest.TestCase):
def linear_fc_quant(
self, activation_quant_type, weight_quantize_type, for_ci=True
):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
loss = linear_fc(3)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt = paddle.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
place = fluid.CPUPlace()
place = paddle.CPUPlace()
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
scope=paddle.static.global_scope(),
place=place,
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quantize_type,
......@@ -217,16 +220,16 @@ class TestQuantizationTransformPass(unittest.TestCase):
quantizable_op_type,
for_ci=True,
):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
loss = residual_block(2)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt = paddle.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
place = fluid.CPUPlace()
place = paddle.CPUPlace()
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
scope=paddle.static.global_scope(),
place=place,
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quantize_type,
......@@ -289,36 +292,36 @@ class TestQuantizationFreezePass(unittest.TestCase):
def build_program(main, startup, is_test):
main.random_seed = seed
startup.random_seed = seed
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
img = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32'
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main, startup):
img = paddle.static.data(
name='image', shape=[-1, 1, 28, 28], dtype='float32'
)
label = fluid.layers.data(
name='label', shape=[1], dtype='int64'
label = paddle.static.data(
name='label', shape=[-1, 1], dtype='int64'
)
loss = conv_net(img, label, quant_skip_pattern)
if not is_test:
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt = paddle.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
return [img, label], loss
random.seed(0)
np.random.seed(0)
main = fluid.Program()
startup = fluid.Program()
test_program = fluid.Program()
main = paddle.static.Program()
startup = paddle.static.Program()
test_program = paddle.static.Program()
feeds, loss = build_program(main, startup, False)
build_program(test_program, startup, True)
test_program = test_program.clone(for_test=True)
main_graph = IrGraph(core.Graph(main.desc), for_test=False)
test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
with fluid.scope_guard(scope):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
exe = paddle.static.Executor(place)
scope = paddle.static.global_scope()
with paddle.static.scope_guard(scope):
exe.run(startup)
transform_pass = QuantizationTransformPass(
scope=scope,
......@@ -365,13 +368,13 @@ class TestQuantizationFreezePass(unittest.TestCase):
marked_nodes,
)
build_strategy = fluid.BuildStrategy()
build_strategy = paddle.static.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
build_strategy.fuse_all_reduce_ops = False
binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy
)
binary = paddle.static.CompiledProgram(
main_graph.graph
).with_data_parallel(loss_name=loss.name, build_strategy=build_strategy)
quantized_test_program = test_graph.to_program()
iters = 5
batch_size = 8
......@@ -383,8 +386,8 @@ class TestQuantizationFreezePass(unittest.TestCase):
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size
)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
with fluid.scope_guard(scope):
feeder = paddle.fluid.DataFeeder(feed_list=feeds, place=place)
with paddle.static.scope_guard(scope):
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(
......@@ -403,12 +406,12 @@ class TestQuantizationFreezePass(unittest.TestCase):
)
test_data = next(test_reader())
with fluid.program_guard(quantized_test_program):
with paddle.static.program_guard(quantized_test_program):
w_var = fluid.framework._get_var(
'conv2d_1.w_0.quantized', quantized_test_program
)
# Testing
with fluid.scope_guard(scope):
with paddle.static.scope_guard(scope):
test_loss1, w_quant = exe.run(
program=quantized_test_program,
feed=feeder.feed(test_data),
......@@ -439,7 +442,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
)
server_program = test_graph.to_program()
with fluid.scope_guard(scope):
with paddle.static.scope_guard(scope):
(test_loss2,) = exe.run(
program=server_program,
feed=feeder.feed(test_data),
......@@ -511,25 +514,32 @@ class TestQuantizationFreezePass(unittest.TestCase):
)
server_program_int8 = test_graph.to_program()
# Save the 8-bit parameter and model file.
with fluid.scope_guard(scope):
fluid.io.save_inference_model(
with paddle.static.scope_guard(scope):
feed_list = ['image', 'label']
feed_vars = [
server_program_int8.global_block().var(name)
for name in feed_list
]
paddle.static.save_inference_model(
'server_int8'
+ dev_name
+ activation_quant_type
+ '_'
+ weight_quant_type,
['image', 'label'],
+ weight_quant_type
+ '/model',
feed_vars,
[loss],
exe,
server_program_int8,
program=server_program_int8,
)
# Test whether the 8-bit parameter and model file can be loaded successfully.
[infer, feed, fetch] = fluid.io.load_inference_model(
[infer, feed, fetch] = paddle.static.load_inference_model(
'server_int8'
+ dev_name
+ activation_quant_type
+ '_'
+ weight_quant_type,
+ weight_quant_type
+ '/model',
exe,
)
# Check the loaded 8-bit weight.
......@@ -576,22 +586,27 @@ class TestQuantizationFreezePass(unittest.TestCase):
)
mobile_program = test_graph.to_program()
with fluid.scope_guard(scope):
fluid.io.save_inference_model(
with paddle.static.scope_guard(scope):
feed_list = ['image', 'label']
feed_vars = [
mobile_program.global_block().var(name) for name in feed_list
]
paddle.static.save_inference_model(
'mobile_int8'
+ dev_name
+ activation_quant_type
+ '_'
+ weight_quant_type,
['image', 'label'],
+ weight_quant_type
+ '/model',
feed_vars,
[loss],
exe,
mobile_program,
program=mobile_program,
)
def test_freeze_graph_cuda_dynamic(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
if core.is_compiled_with_cuda():
with paddle.utils.unique_name.guard():
self.freeze_graph(
True,
seed=1,
......@@ -599,7 +614,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
weight_quant_type='abs_max',
for_ci=True,
)
with fluid.unique_name.guard():
with paddle.utils.unique_name.guard():
self.freeze_graph(
True,
seed=1,
......@@ -609,7 +624,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
)
def test_freeze_graph_cpu_dynamic(self):
with fluid.unique_name.guard():
with paddle.utils.unique_name.guard():
self.freeze_graph(
False,
seed=2,
......@@ -626,8 +641,8 @@ class TestQuantizationFreezePass(unittest.TestCase):
)
def test_freeze_graph_cuda_static(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
if core.is_compiled_with_cuda():
with paddle.utils.unique_name.guard():
self.freeze_graph(
True,
seed=1,
......@@ -674,7 +689,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
)
def test_freeze_graph_cpu_static(self):
with fluid.unique_name.guard():
with paddle.utils.unique_name.guard():
self.freeze_graph(
False,
seed=2,
......@@ -720,48 +735,50 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None):
)
return paddle.static.nn.batch_norm(input=tmp, act=act)
data1 = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
data2 = fluid.layers.data(
name='matmul_input', shape=[16, 32, 32], dtype='float32'
data1 = paddle.static.data(
name='image', shape=[-1, 1, 32, 32], dtype='float32'
)
data2 = paddle.static.data(
name='matmul_input', shape=[-1, 16, 32, 32], dtype='float32'
)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')
hidden = data1
for _ in range(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = paddle.nn.functional.relu(paddle.add(x=conv, y=short))
hidden = paddle.add(x=conv, y=short)
hidden = paddle.nn.functional.relu(hidden)
hidden = paddle.matmul(hidden, data2, True, True)
if isinstance(quant_skip_pattern, str):
with fluid.name_scope(quant_skip_pattern):
with paddle.static.name_scope(quant_skip_pattern):
pool1 = paddle.nn.functional.avg_pool2d(
x=hidden, kernel_size=2, stride=2
hidden, kernel_size=2, stride=2
)
pool2 = paddle.nn.functional.max_pool2d(
x=hidden, kernel_size=2, stride=2
hidden, kernel_size=2, stride=2
)
pool_add = paddle.nn.functional.relu(paddle.add(x=pool1, y=pool2))
pool_add = paddle.add(pool1, pool2)
pool_add = paddle.nn.functional.relu(pool_add)
elif isinstance(quant_skip_pattern, list):
assert (
len(quant_skip_pattern) > 1
), 'test config error: the len of quant_skip_pattern list should be greater than 1.'
with fluid.name_scope(quant_skip_pattern[0]):
with paddle.static.name_scope(quant_skip_pattern[0]):
pool1 = paddle.nn.functional.avg_pool2d(
x=hidden, kernel_size=2, stride=2
hidden, kernel_size=2, stride=2
)
pool2 = paddle.nn.functional.max_pool2d(
x=hidden, kernel_size=2, stride=2
hidden, kernel_size=2, stride=2
)
with fluid.name_scope(quant_skip_pattern[1]):
pool_add = paddle.nn.functional.relu(paddle.add(x=pool1, y=pool2))
with paddle.static.name_scope(quant_skip_pattern[1]):
pool_add = paddle.add(pool1, pool2)
pool_add = paddle.nn.functional.relu(pool_add)
else:
pool1 = paddle.nn.functional.avg_pool2d(
x=hidden, kernel_size=2, stride=2
)
pool2 = paddle.nn.functional.max_pool2d(
x=hidden, kernel_size=2, stride=2
)
pool_add = paddle.nn.functional.relu(paddle.add(x=pool1, y=pool2))
fc = fluid.layers.fc(input=pool_add, size=10)
pool1 = paddle.nn.functional.avg_pool2d(hidden, kernel_size=2, stride=2)
pool2 = paddle.nn.functional.max_pool2d(hidden, kernel_size=2, stride=2)
pool_add = paddle.add(pool1, pool2)
pool_add = paddle.nn.functional.relu(pool_add)
fc = paddle.static.nn.fc(pool_add, size=10)
loss = paddle.nn.functional.cross_entropy(
input=fc, label=label, reduction='none', use_softmax=False
)
......@@ -814,16 +831,16 @@ class TestAddQuantDequantPass(unittest.TestCase):
def residual_block_quant(
self, quantizable_op_type, skip_pattern=None, for_ci=True
):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
loss = quant_dequant_residual_block(2, skip_pattern)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt = paddle.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
place = fluid.CPUPlace()
place = paddle.CPUPlace()
graph = IrGraph(core.Graph(main.desc), for_test=False)
add_quant_dequant_pass = AddQuantDequantPass(
scope=fluid.global_scope(),
scope=paddle.static.global_scope(),
place=place,
skip_pattern=skip_pattern,
quantizable_op_type=quantizable_op_type,
......@@ -904,16 +921,16 @@ class TestQuantizationTransformPassV2(unittest.TestCase):
def linear_fc_quant(
self, activation_quant_type, weight_quantize_type, for_ci=True
):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
loss = linear_fc(3)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt = paddle.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
place = fluid.CPUPlace()
place = paddle.CPUPlace()
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPassV2(
scope=fluid.global_scope(),
scope=paddle.static.global_scope(),
place=place,
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quantize_type,
......@@ -952,16 +969,16 @@ class TestQuantizationTransformPassV2(unittest.TestCase):
quantizable_op_type,
for_ci=True,
):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
loss = residual_block(2)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt = paddle.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
place = fluid.CPUPlace()
place = paddle.CPUPlace()
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
scope=paddle.static.global_scope(),
place=place,
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quantize_type,
......
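The pass tests above all follow the same pattern: build the network with paddle.static, wrap it in an IrGraph, and hand paddle.static.global_scope() plus a paddle place to the pass constructor. A condensed, illustrative sketch with a deliberately tiny network (the real tests use the linear_fc/residual_block helpers defined earlier):

```python
import paddle
from paddle.fluid.framework import IrGraph
from paddle.framework import core
from paddle.static.quantization import QuantizationTransformPass

paddle.enable_static()
place = paddle.CPUPlace()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    img = paddle.static.data(name="image", shape=[-1, 1, 28, 28], dtype="float32")
    fc = paddle.static.nn.fc(img, size=10)
    loss = paddle.mean(fc)
    paddle.optimizer.Adam(learning_rate=0.001).minimize(loss)

graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
    scope=paddle.static.global_scope(),
    place=place,
    activation_quantize_type="abs_max",
    weight_quantize_type="abs_max",
)
transform_pass.apply(graph)  # inserts fake quant/dequant ops into the graph
```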
......@@ -13,19 +13,22 @@
# limitations under the license.
import os
import unittest
import random
import numpy as np
import tempfile
import paddle.fluid as fluid
import unittest
import numpy as np
import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core
from paddle.framework import core
from paddle.static.quantization import (
AddQuantDequantPass,
OutScaleForInferencePass,
OutScaleForTrainingPass,
QuantizationFreezePass,
QuantizationTransformPass,
)
paddle.enable_static()
......@@ -34,27 +37,27 @@ os.environ["CPU_NUM"] = "1"
def conv_net(img, label):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
conv_out_1 = paddle.static.nn.conv2d(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
pool_type='max',
act="relu",
act='relu',
)
conv_pool_1 = paddle.nn.functional.max_pool2d(
conv_out_1, kernel_size=2, stride=2
)
conv_pool_1 = paddle.static.nn.batch_norm(conv_pool_1)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
conv_out_2 = paddle.static.nn.conv2d(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
pool_type='avg',
act="relu",
num_filters=20,
act='relu',
)
hidden = fluid.layers.fc(input=conv_pool_2, size=100, act='relu')
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
conv_pool_2 = paddle.nn.functional.avg_pool2d(
conv_out_2, kernel_size=2, stride=2
)
hidden = paddle.static.nn.fc(conv_pool_2, size=100, activation='relu')
prediction = paddle.static.nn.fc(hidden, size=10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(
input=prediction, label=label, reduction='none', use_softmax=False
)
......@@ -74,36 +77,36 @@ class TestQuantizationScalePass(unittest.TestCase):
def build_program(main, startup, is_test):
main.random_seed = seed
startup.random_seed = seed
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
img = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32'
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main, startup):
img = paddle.static.data(
name='image', shape=[-1, 1, 28, 28], dtype='float32'
)
label = fluid.layers.data(
name='label', shape=[1], dtype='int64'
label = paddle.static.data(
name='label', shape=[-1, 1], dtype='int64'
)
loss = conv_net(img, label)
if not is_test:
opt = fluid.optimizer.Adam(learning_rate=0.0001)
opt = paddle.optimizer.Adam(learning_rate=0.0001)
opt.minimize(loss)
return [img, label], loss
random.seed(0)
np.random.seed(0)
main = fluid.Program()
startup = fluid.Program()
test_program = fluid.Program()
main = paddle.static.Program()
startup = paddle.static.Program()
test_program = paddle.static.Program()
feeds, loss = build_program(main, startup, False)
build_program(test_program, startup, True)
test_program = test_program.clone(for_test=True)
main_graph = IrGraph(core.Graph(main.desc), for_test=False)
test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
with fluid.scope_guard(scope):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
exe = paddle.static.Executor(place)
scope = paddle.static.global_scope()
with paddle.static.scope_guard(scope):
exe.run(startup)
transform_pass = QuantizationTransformPass(
......@@ -135,13 +138,13 @@ class TestQuantizationScalePass(unittest.TestCase):
marked_nodes.add(op)
test_graph.draw('.', 'test_scale' + dev_name, marked_nodes)
build_strategy = fluid.BuildStrategy()
build_strategy = paddle.static.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
build_strategy.fuse_all_reduce_ops = False
binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy
)
binary = paddle.static.CompiledProgram(
main_graph.graph
).with_data_parallel(loss_name=loss.name, build_strategy=build_strategy)
iters = 5
batch_size = 8
......@@ -149,8 +152,8 @@ class TestQuantizationScalePass(unittest.TestCase):
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
batch_size=batch_size,
)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
with fluid.scope_guard(scope):
feeder = paddle.fluid.DataFeeder(feed_list=feeds, place=place)
with paddle.static.scope_guard(scope):
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(
......@@ -184,20 +187,24 @@ class TestQuantizationScalePass(unittest.TestCase):
with open(mapping_table_path, 'w') as f:
f.write(str(server_program))
with fluid.scope_guard(scope):
fluid.io.save_inference_model(
with paddle.static.scope_guard(scope):
feed_list = ['image', 'label']
feed_vars = [
server_program.global_block().var(name) for name in feed_list
]
paddle.static.save_inference_model(
save_path,
['image', 'label'],
feed_vars,
[loss],
exe,
server_program,
program=server_program,
clip_extra=True,
)
tempdir.cleanup()
def test_quant_scale_cuda(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
if core.is_compiled_with_cuda():
with paddle.utils.unique_name.guard():
self.quantization_scale(
True,
seed=1,
......@@ -207,7 +214,7 @@ class TestQuantizationScalePass(unittest.TestCase):
)
def test_quant_scale_cpu(self):
with fluid.unique_name.guard():
with paddle.utils.unique_name.guard():
self.quantization_scale(
False,
seed=2,
......
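Besides the pass plumbing, this file switches from fluid.io.save_inference_model, which took feed variable names, to paddle.static.save_inference_model, which takes the feed Variables themselves and a path prefix. A self-contained sketch of that saving pattern; the network and the output prefix are placeholders:

```python
import paddle

paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    image = paddle.static.data(name="image", shape=[-1, 1, 28, 28], dtype="float32")
    label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
    prediction = paddle.static.nn.fc(image, size=10, activation="softmax")
    loss = paddle.mean(
        paddle.nn.functional.cross_entropy(
            prediction, label, reduction="none", use_softmax=False
        )
    )
exe.run(startup)

# feed_vars are Variables (not names), and the first argument is a path prefix
# that produces <prefix>.pdmodel / <prefix>.pdiparams.
feed_vars = [main.global_block().var(name) for name in ["image", "label"]]
paddle.static.save_inference_model(
    "quant_scale_demo/model",   # hypothetical path prefix
    feed_vars,
    [loss],
    exe,
    program=main,
)
```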
......@@ -12,23 +12,24 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import os
import unittest
import json
import os
import random
import numpy as np
import tempfile
import paddle.fluid as fluid
import unittest
import numpy as np
import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
import paddle.nn.functional as F
from paddle.framework import LayerHelper, core
from paddle.static.quantization import (
AddQuantDequantPass,
OutScaleForInferencePass,
OutScaleForTrainingPass,
QuantizationFreezePass,
QuantizationTransformPass,
)
paddle.enable_static()
......@@ -37,27 +38,27 @@ os.environ["CPU_NUM"] = "1"
def conv_net(img, label):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
conv_out_1 = paddle.static.nn.conv2d(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
pool_type='max',
act="relu",
act='relu',
)
conv_pool_1 = paddle.nn.functional.max_pool2d(
conv_out_1, kernel_size=2, stride=2
)
conv_pool_1 = paddle.static.nn.batch_norm(conv_pool_1)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
conv_out_2 = paddle.static.nn.conv2d(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
pool_type='avg',
act="relu",
num_filters=20,
act='relu',
)
hidden = fluid.layers.fc(input=conv_pool_2, size=100, act='relu')
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
conv_pool_2 = paddle.nn.functional.avg_pool2d(
conv_out_2, kernel_size=2, stride=2
)
hidden = paddle.static.nn.fc(conv_pool_2, size=100, activation='relu')
prediction = paddle.static.nn.fc(hidden, size=10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(
input=prediction, label=label, reduction='none', use_softmax=False
)
......@@ -69,15 +70,17 @@ def pact(x, name=None):
helper = LayerHelper("pact", **locals())
dtype = 'float32'
init_thres = 20
u_param_attr = fluid.ParamAttr(
u_param_attr = paddle.ParamAttr(
name=x.name + '_pact',
initializer=fluid.initializer.ConstantInitializer(value=init_thres),
regularizer=fluid.regularizer.L2Decay(0.0001),
initializer=paddle.nn.initializer.Constant(value=init_thres),
regularizer=paddle.regularizer.L2Decay(0.0001),
learning_rate=1,
)
u_param = helper.create_parameter(attr=u_param_attr, shape=[1], dtype=dtype)
x = paddle.subtract(x, F.relu(paddle.subtract(x, u_param)))
x = paddle.add(x, F.relu(paddle.subtract(-u_param, x)))
x = paddle.subtract(
x, paddle.nn.functional.relu(paddle.subtract(x, u_param))
)
x = paddle.add(x, paddle.nn.functional.relu(paddle.subtract(-u_param, x)))
return x
......@@ -98,23 +101,23 @@ class TestUserDefinedQuantization(unittest.TestCase):
def build_program(main, startup, is_test):
main.random_seed = seed
startup.random_seed = seed
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
img = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32'
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main, startup):
img = paddle.static.data(
name='image', shape=[-1, 1, 28, 28], dtype='float32'
)
img.stop_gradient = False
label = fluid.layers.data(
name='label', shape=[1], dtype='int64'
label = paddle.static.data(
name='label', shape=[-1, 1], dtype='int64'
)
loss = conv_net(img, label)
if not is_test:
opt = fluid.optimizer.SGD(learning_rate=0.0001)
opt = paddle.optimizer.SGD(learning_rate=0.0001)
opt.minimize(loss)
return [img, label], loss
def get_optimizer():
return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
return paddle.optimizer.Momentum(0.0001, 0.9)
def load_dict(mapping_table_path):
with open(mapping_table_path, 'r') as file:
......@@ -131,19 +134,19 @@ class TestUserDefinedQuantization(unittest.TestCase):
tempdir = tempfile.TemporaryDirectory()
mapping_table_path = os.path.join(tempdir.name, 'inference')
main = fluid.Program()
startup = fluid.Program()
test_program = fluid.Program()
main = paddle.static.Program()
startup = paddle.static.Program()
test_program = paddle.static.Program()
feeds, loss = build_program(main, startup, False)
build_program(test_program, startup, True)
test_program = test_program.clone(for_test=True)
main_graph = IrGraph(core.Graph(main.desc), for_test=False)
test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
with fluid.scope_guard(scope):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
exe = paddle.static.Executor(place)
scope = paddle.static.global_scope()
with paddle.static.scope_guard(scope):
exe.run(startup)
train_transform_pass = QuantizationTransformPass(
scope=scope,
......@@ -183,13 +186,13 @@ class TestUserDefinedQuantization(unittest.TestCase):
dev_name = '_gpu' if use_cuda else '_cpu'
build_strategy = fluid.BuildStrategy()
build_strategy = paddle.static.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
build_strategy.fuse_all_reduce_ops = False
binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy
)
binary = paddle.static.CompiledProgram(
main_graph.graph
).with_data_parallel(loss_name=loss.name, build_strategy=build_strategy)
iters = 5
batch_size = 8
......@@ -197,8 +200,8 @@ class TestUserDefinedQuantization(unittest.TestCase):
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
batch_size=batch_size,
)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
with fluid.scope_guard(scope):
feeder = paddle.fluid.DataFeeder(feed_list=feeds, place=place)
with paddle.static.scope_guard(scope):
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(
......@@ -223,8 +226,8 @@ class TestUserDefinedQuantization(unittest.TestCase):
tempdir.cleanup()
def test_act_preprocess_cuda(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
if core.is_compiled_with_cuda():
with paddle.utils.unique_name.guard():
self.quantization_scale(
True,
seed=1,
......@@ -235,7 +238,7 @@ class TestUserDefinedQuantization(unittest.TestCase):
)
def test_act_preprocess_cpu(self):
with fluid.unique_name.guard():
with paddle.utils.unique_name.guard():
self.quantization_scale(
False,
seed=2,
......@@ -246,8 +249,8 @@ class TestUserDefinedQuantization(unittest.TestCase):
)
def test_weight_preprocess_cuda(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
if core.is_compiled_with_cuda():
with paddle.utils.unique_name.guard():
self.quantization_scale(
True,
seed=1,
......@@ -258,7 +261,7 @@ class TestUserDefinedQuantization(unittest.TestCase):
)
def test_weight_preprocess_cpu(self):
with fluid.unique_name.guard():
with paddle.utils.unique_name.guard():
self.quantization_scale(
False,
seed=2,
......@@ -269,8 +272,8 @@ class TestUserDefinedQuantization(unittest.TestCase):
)
def test_act_quantize_cuda(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
if core.is_compiled_with_cuda():
with paddle.utils.unique_name.guard():
self.quantization_scale(
True,
seed=1,
......@@ -281,7 +284,7 @@ class TestUserDefinedQuantization(unittest.TestCase):
)
def test_act_quantize_cpu(self):
with fluid.unique_name.guard():
with paddle.utils.unique_name.guard():
self.quantization_scale(
False,
seed=2,
......@@ -292,8 +295,8 @@ class TestUserDefinedQuantization(unittest.TestCase):
)
def test_weight_quantize_cuda(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
if core.is_compiled_with_cuda():
with paddle.utils.unique_name.guard():
self.quantization_scale(
True,
seed=1,
......@@ -304,7 +307,7 @@ class TestUserDefinedQuantization(unittest.TestCase):
)
def test_weight_quantize_cpu(self):
with fluid.unique_name.guard():
with paddle.utils.unique_name.guard():
self.quantization_scale(
False,
seed=2,
......
......@@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import time
import unittest
import numpy as np
from paddle.dataset.common import download, DATA_HOME
from paddle.fluid.contrib.slim.quantization import WeightQuantization
import paddle
from paddle.dataset.common import DATA_HOME, download
from paddle.static.quantization import WeightQuantization
paddle.enable_static()
......@@ -73,6 +75,8 @@ class TestWeightQuantization(unittest.TestCase):
def quantize_to_int(
self,
model_name,
model_filename,
params_filename,
model_data_url,
model_data_md5,
weight_bits,
......@@ -93,7 +97,11 @@ class TestWeightQuantization(unittest.TestCase):
model_name + "_wq_" + str(weight_bits) + "_" + timestamp,
)
weight_quant = WeightQuantization(model_dir=load_model_dir)
weight_quant = WeightQuantization(
model_dir=load_model_dir,
model_filename=model_filename,
params_filename=params_filename,
)
weight_quant.quantize_weight_to_int(
save_model_dir=save_model_dir,
weight_bits=weight_bits,
......@@ -183,7 +191,7 @@ class TestWeightQuantization(unittest.TestCase):
inference_program,
feed_target_names,
fetch_targets,
] = paddle.fluid.io.load_inference_model(
] = paddle.static.load_inference_model(
model_dir,
exe,
model_filename=model_filename,
......@@ -193,10 +201,10 @@ class TestWeightQuantization(unittest.TestCase):
if is_fp16_model:
for var in inference_program.list_vars():
if (
(var.type == paddle.fluid.core.VarDesc.VarType.RAW)
(var.type == paddle.framework.core.VarDesc.VarType.RAW)
or (not var.persistable)
or (var.name in ['feed', 'fetch'])
or (var.dtype != paddle.fluid.core.VarDesc.VarType.FP16)
or (var.dtype != paddle.framework.core.VarDesc.VarType.FP16)
):
continue
tensor = _load_variable_data(scope, var.name)
......@@ -228,9 +236,11 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
generate_test_model = True
threshold_rate = 0.0
self.quantize_to_int(
self.nocomb_model_name,
self.nocomb_model_data_url,
self.nocomb_model_data_md5,
self.comb_model_name,
'__model__',
'__params__',
self.comb_model_data_url,
self.comb_model_data_md5,
weight_bits,
quantizable_op_type,
weight_quantize_type,
......@@ -245,9 +255,11 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
generate_test_model = True
threshold_rate = 0.0
self.quantize_to_int(
self.nocomb_model_name,
self.nocomb_model_data_url,
self.nocomb_model_data_md5,
self.comb_model_name,
'__model__',
'__params__',
self.comb_model_data_url,
self.comb_model_data_md5,
weight_bits,
quantizable_op_type,
weight_quantize_type,
......@@ -262,9 +274,11 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
generate_test_model = False
threshold_rate = 0
self.quantize_to_int(
self.nocomb_model_name,
self.nocomb_model_data_url,
self.nocomb_model_data_md5,
self.comb_model_name,
'__model__',
'__params__',
self.comb_model_data_url,
self.comb_model_data_md5,
weight_bits,
quantizable_op_type,
weight_quantize_type,
......@@ -279,9 +293,11 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
generate_test_model = False
threshold_rate = 1e-9
self.quantize_to_int(
self.nocomb_model_name,
self.nocomb_model_data_url,
self.nocomb_model_data_md5,
self.comb_model_name,
'__model__',
'__params__',
self.comb_model_data_url,
self.comb_model_data_md5,
weight_bits,
quantizable_op_type,
weight_quantize_type,
......@@ -300,17 +316,6 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
params_filename,
)
def test_mobilenetv1_fp16_nocombined(self):
model_filename = None
params_filename = None
self.convert_to_fp16(
self.nocomb_model_name,
self.nocomb_model_data_url,
self.nocomb_model_data_md5,
model_filename,
params_filename,
)
if __name__ == '__main__':
unittest.main()
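The weight-quantization tests now target only the combined-format MobileNetV1 model and pass its file names explicitly to WeightQuantization. A minimal sketch of that call, assuming an FP32 model stored as __model__/__params__; the keyword arguments mirror the parameters the test forwards and should be treated as illustrative:

```python
import paddle
from paddle.static.quantization import WeightQuantization

paddle.enable_static()

# Quantize only the weights of a combined-format FP32 model to 8 bits (sketch).
weight_quant = WeightQuantization(
    model_dir="mobilenetv1_fp32",      # hypothetical directory holding __model__/__params__
    model_filename="__model__",
    params_filename="__params__",
)
weight_quant.quantize_weight_to_int(
    save_model_dir="mobilenetv1_w8",   # hypothetical output directory
    weight_bits=8,
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    weight_quantize_type="channel_wise_abs_max",
    generate_test_model=False,
)
```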
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,9 +13,10 @@
# limitations under the License.
import sys
import numpy as np
from ....framework import IrNode
from ....framework import Operator
from ...fluid.framework import IrNode, Operator
_weight_supported_quantizable_op_type = [
'conv2d',
......@@ -158,7 +159,6 @@ _op_real_in_out_name = {
"reshape": [["X"], ["Out"]],
"reshape2": [["X"], ["Out"]],
"transpose2": [["X"], ["Out"]],
"bilinear_interp": [["X"], ["Out"]],
"nearest_interp": [["X"], ["Out"]],
"trilinear_interp": [["X"], ["Out"]],
"slice": [["Input"], ["Out"]],
......@@ -185,7 +185,6 @@ _op_real_in_out_name = {
"flatten": [["X"], ["Out"]],
"flatten2": [["X"], ["Out"]],
"unsqueeze2": [["X"], ["Out"]],
"unsqueeze2": [["X"], ["Out"]],
"flatten_contiguous_range": [["X"], ["Out"]],
"split": [["X"], ["Out"]],
"squeeze2": [["X"], ["Out"]],
......
......@@ -338,10 +338,6 @@ packages=['paddle',
'paddle.fluid.layers',
'paddle.fluid.dataloader',
'paddle.fluid.contrib',
'paddle.fluid.contrib.quantize',
'paddle.fluid.contrib.slim',
'paddle.fluid.contrib.slim.quantization',
'paddle.fluid.contrib.slim.quantization.imperative',
'paddle.fluid.contrib.extend_optimizer',
'paddle.fluid.contrib.mixed_precision',
'paddle.fluid.contrib.mixed_precision.bf16',
......@@ -405,6 +401,9 @@ packages=['paddle',
'paddle.static',
'paddle.static.nn',
'paddle.static.amp',
'paddle.static.quantization',
'paddle.quantization',
'paddle.quantization.imperative',
'paddle.tensor',
'paddle.onnx',
'paddle.autograd',
......
......@@ -1209,10 +1209,6 @@ def get_setup_parameters():
'paddle.fluid.layers',
'paddle.fluid.dataloader',
'paddle.fluid.contrib',
'paddle.fluid.contrib.quantize',
'paddle.fluid.contrib.slim',
'paddle.fluid.contrib.slim.quantization',
'paddle.fluid.contrib.slim.quantization.imperative',
'paddle.fluid.contrib.extend_optimizer',
'paddle.fluid.contrib.mixed_precision',
'paddle.fluid.contrib.mixed_precision.bf16',
......@@ -1276,6 +1272,9 @@ def get_setup_parameters():
'paddle.static',
'paddle.static.nn',
'paddle.static.amp',
'paddle.static.quantization',
'paddle.quantization',
'paddle.quantization.imperative',
'paddle.tensor',
'paddle.onnx',
'paddle.autograd',
......
......@@ -486,7 +486,7 @@ def get_filenames(full_test=False):
'''
global whl_error
import paddle # noqa: F401
import paddle.fluid.contrib.slim.quantization # noqa: F401
import paddle.static.quantization # noqa: F401
whl_error = []
if full_test:
......