Unverified commit 1d197f6c authored by cc, committed by GitHub

[dygraph qat] Refine calculating output scale of dygraph qat (#31710)

* Refine calculating output scale of dygraph qat, test=develop

Parent 420527f0
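For context, a minimal sketch of the dygraph QAT flow that this commit touches: quantize a model, run forward passes so the output-scale hooks fire, then export. The toy network, tensor shapes, and save path below are illustrative assumptions, not lines from this diff.

# Minimal usage sketch (illustrative, not part of the diff).
import paddle
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware


class ToyNet(paddle.nn.Layer):
    def __init__(self):
        super(ToyNet, self).__init__()
        self.conv = paddle.nn.Conv2D(3, 8, 3)
        self.relu = paddle.nn.ReLU()
        self.fc = paddle.nn.Linear(8 * 30 * 30, 10)

    def forward(self, x):
        x = self.relu(self.conv(x))
        return self.fc(paddle.flatten(x, 1))


model = ToyNet()
qat = ImperativeQuantAware(
    weight_quantize_type='abs_max',
    activation_quantize_type='moving_average_abs_max')
qat.quantize(model)  # replace quantizable sublayers with fake-quant wrappers

# A real training loop would go here; every forward pass also feeds the hooks
# that maintain the moving-average output scales.
model(paddle.rand([1, 3, 32, 32]))

qat.save_quantized_model(
    layer=model,
    path='./quantized_toy/model',
    input_spec=[
        paddle.static.InputSpec(shape=[None, 3, 32, 32], dtype='float32')
    ])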
@@ -25,12 +25,7 @@ from paddle.fluid.executor import Executor
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.initializer import Constant
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
-from paddle.nn import Linear, Conv2D, Conv2DTranspose, MaxPool2D, MaxPool1D
-from paddle.nn import BatchNorm1D, BatchNorm2D, BatchNorm3D, SyncBatchNorm
-from paddle.fluid.dygraph.nn import BatchNorm, Pool2D
 from paddle.fluid.io import load_inference_model, save_inference_model
-from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6
-from paddle.nn.layer.activation import Tanh, Softmax, PReLU, Swish
 from paddle.fluid.log_helper import get_logger
 from . import quant_nn
 from .. import quantization_pass
@@ -62,14 +57,10 @@ class ImperativeQuantAware(object):
         The constructor for ImperativeQuantAware.

         Args:
-            quantizable_layer_type(list[str]): List the type of layers that
-                will be quantized. Default is ['Conv2D', 'Linear'].
-                The quantizable_op_type in QuantizationFreezePass and
-                ConvertToInt8Pass must be the same as this.
+            quantizable_layer_type(list[str | layer]): List the type of
+                layers that will be quantized. Default is ['Conv2D', 'Linear'].
             weight_quantize_type(str): quantization type for weights,
-                which supports 'abs_max' now. The 'moving_average_abs_max'
-                usually is not used for weights, since weights are fixed
-                once the model is well trained.
+                which supports 'abs_max' and 'channel_wise_abs_max'.
             activation_quantize_type(str): quantization type for activations,
                 which supports 'abs_max' and 'moving_average_abs_max' now.
                 If using 'abs_max' mode, the quantization scale will be
@@ -77,8 +68,8 @@ class ImperativeQuantAware(object):
                 period. If using 'moving_average_abs_max', the static
                 quantization scale will be calculated during training and
                 used in inference.
-            weight_bits(int): quantization bit number for weights,
-                whereas the bias is not quantized.
+            weight_bits(int): quantization bit number for weights, whereas
+                the bias is not quantized.
             activation_bits(int): quantization bit number for activations.
             moving_rate(float): the parameter for 'moving_average_abs_max'
                 quantization.
@@ -260,8 +251,8 @@ class ImperativeQuantizeInputs(object):
         super(ImperativeQuantizeInputs, self).__init__()

         self._quantizable_layer_type = tuple(
-            utils._quant_layers_map[layer]
-            if layer in utils._quant_layers_map else layer
+            utils.supported_quant_layers_map[layer]
+            if layer in utils.supported_quant_layers_map else layer
             for layer in quantizable_layer_type)
         for layer in self._quantizable_layer_type:
             assert not isinstance(layer, str), \
@@ -338,7 +329,7 @@ class ImperativeQuantizeInputs(object):

     def _get_quantized_layer(self, layer):
         quant_layer_name = None
-        for key, value in utils._quant_layers_map.items():
+        for key, value in utils.supported_quant_layers_map.items():
             if isinstance(layer, value):
                 quant_layer_name = 'Quantized' + key
                 break
@@ -364,10 +355,6 @@ class ImperativeCalcOutputScale(object):
         """
         super(ImperativeCalcOutputScale, self).__init__()
         self._moving_rate = moving_rate
-        self._out_scale_layer_type_list = (
-            BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D, LeakyReLU,
-            Linear, PReLU, Pool2D, MaxPool1D, MaxPool2D, ReLU, ReLU6, Sigmoid,
-            Softmax, SyncBatchNorm, Tanh, Swish)
         self._register_hook_handle_list = []
         self._out_scale_dict = collections.OrderedDict()
@@ -378,7 +365,7 @@ class ImperativeCalcOutputScale(object):
         Args:
             model(fluid.dygraph.Layer): The target model which would be
-                calculate the output quantization scale.
+                calculate the output quantization scale.

         Returns:
             None
@@ -387,10 +374,10 @@ class ImperativeCalcOutputScale(object):
             "The model must be the instance of dygraph.Layer."
         for _, layer in model.named_sublayers():
             if self._is_target_layer(layer):
-                self._add_new_parameters(layer)
-                forward_post_hook_handle = layer.register_forward_post_hook(
-                    self._forward_post_hook)
-                self._register_hook_handle_list.append(forward_post_hook_handle)
+                self._init_scale_params(layer)
+                hook_handle = layer.register_forward_post_hook(
+                    self._calc_output_scale_hook)
+                self._register_hook_handle_list.append(hook_handle)

     def save_quantized_model(self, layer, path, input_spec=None, **config):
         """
@@ -398,63 +385,64 @@ class ImperativeCalcOutputScale(object):
         Args:
             layer (Layer): The Layer to be saved.
-            path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
-            input_spec (list[InputSpec|Tensor], optional): Describes the input of the saved model's forward
-                method, which can be described by InputSpec or example Tensor. If None, all input variables of
-                the original Layer's forward method would be the inputs of the saved model. Default None.
-            **configs (dict, optional): Other save configuration options for compatibility. We do not
-                recommend using these configurations, they may be removed in the future. If not necessary,
-                DO NOT use them. Default None.
+            path (str): The path prefix to save model. The format is
+                ``dirname/file_prefix`` or ``file_prefix``.
+            input_spec (list[InputSpec|Tensor], optional): Describes the input
+                of the saved model's forward method, which can be described by
+                InputSpec or example Tensor. If None, all input variables of
+                the original Layer's forward method would be the inputs of
+                the saved model. Default None.
+            **configs (dict, optional): Other save configuration options for
+                compatibility. We do not recommend using these configurations,
+                they may be removed in the future. If not necessary, DO NOT use
+                them. Default None.
                 The following options are currently supported:
-                (1) output_spec (list[Tensor]): Selects the output targets of the saved model.
-                By default, all return variables of original Layer's forward method are kept as the
-                output of the saved model. If the provided ``output_spec`` list is not all output variables,
-                the saved model will be pruned according to the given ``output_spec`` list.
+                (1) output_spec (list[Tensor]): Selects the output targets of
+                the saved model. By default, all return variables of original
+                Layer's forward method are kept as the output of the saved model.
+                If the provided ``output_spec`` list is not all output variables,
+                the saved model will be pruned according to the given
+                ``output_spec`` list.

         Returns:
             None
         """
-        assert isinstance(
-            layer, dygraph.Layer), "model must be the instance of dygraph.Layer"
-        self._layer = layer
-        is_dynamic_mode = False
+        assert isinstance(layer, dygraph.Layer), \
+            "The model must be the instance of dygraph.Layer."
+
+        # remove handles and collect output scales
         with dygraph.guard():
-            self._layer.eval()
-            if self._register_hook_handle_list is not None:
-                for handle in self._register_hook_handle_list:
-                    handle.remove()
-            if self._out_scale_dict:
-                for key in self._out_scale_dict:
-                    self._out_scale_dict[key] = float(self._out_scale_dict[key]
-                                                      .numpy())
-            else:
-                for _, sub_layer in self._layer.named_sublayers():
-                    if self._is_target_layer(sub_layer):
-                        layer_name = sub_layer.full_name()
-                        if hasattr(sub_layer, "layer_name"):
-                            layer_name = sub_layer.layer_name
-                        if hasattr(sub_layer, "_quant_out_scale"):
-                            self._out_scale_dict[layer_name] = float(
-                                sub_layer._quant_out_scale)
+            layer.eval()
+            for handle in self._register_hook_handle_list:
+                handle.remove()
+            for _, sub_layer in layer.named_sublayers():
+                if self._is_target_layer(sub_layer):
+                    if hasattr(sub_layer, "layer_name"):
+                        layer_name = sub_layer.layer_name
+                    else:
+                        layer_name = sub_layer.full_name()
+                    if hasattr(sub_layer, "_quant_out_scale"):
+                        self._out_scale_dict[layer_name] = float(
+                            sub_layer._quant_out_scale)

+        # save the quantized model that doesn't have output scales
+        paddle.jit.save(layer=layer, path=path, input_spec=input_spec, **config)
+
+        # load static model
+        is_dynamic_mode = False
         if paddle.in_dynamic_mode():
             is_dynamic_mode = True
             paddle.enable_static()

-        paddle.jit.save(layer=layer, path=path, input_spec=input_spec, **config)
-
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-        else:
-            place = core.CPUPlace()
+        place = core.CUDAPlace(0) if core.is_compiled_with_cuda() \
+            else core.CPUPlace()
         exe = Executor(place)

-        file_prefix = os.path.basename(path)
         dirname = os.path.dirname(path)
-        model_filename = file_prefix + INFER_MODEL_SUFFIX
-        params_filename = file_prefix + INFER_PARAMS_SUFFIX
+        basename = os.path.basename(path)
+        model_filename = basename + INFER_MODEL_SUFFIX
+        params_filename = basename + INFER_PARAMS_SUFFIX

         [inference_program, feed_target_names, fetch_targets] = (
             load_inference_model(
                 dirname=dirname,
@@ -462,14 +450,15 @@ class ImperativeCalcOutputScale(object):
                 model_filename=model_filename,
                 params_filename=params_filename))

+        # set output scales to the static model
         check_behind_op = False
         op_count = 0
         ops_list = [key for key, _ in self._out_scale_dict.items()]
         if len(ops_list) == 0:
             warnings.warn(
-                "Warning: No Layer of the model while to be saved contains the out_threshold attribute, "
-                "so the generated inference model would not contain the out_threshold."
-            )
+                "Warning: No Layer of the model while to be saved contains "
+                "the out_threshold attribute, so the generated inference "
+                "model would not contain the out_threshold.")
         else:
             # Because the Layer in dygraph may correspond to multiple ops
             # in static program after being saved. To ensure correctness,
@@ -481,11 +470,12 @@ class ImperativeCalcOutputScale(object):
             forward_op = None
             for block in inference_program.blocks:
                 for op in block.ops:
-                    if op.type in utils._op_real_in_out_name:
+                    if op.type in utils.op_real_in_out_name:
                         if op_count > len(ops_list):
                             warnings.warn(
-                                "The number of Layer which has out_threshold attribute should be bigger than the op in inference model"
-                            )
+                                "The number of Layer which has "
+                                "out_threshold attribute should be bigger than "
+                                "the op in inference model")
                             break
                         if check_behind_op:
                             check_behind_op = False
@@ -525,7 +515,7 @@ class ImperativeCalcOutputScale(object):
                                 self._out_scale_dict[ops_list[op_count]])
                             op_count += 1

-        # Save the processed program.
+        # save the final quantized model that has output scales
         save_inference_model(
             dirname=dirname,
             feeded_var_names=feed_target_names,
@@ -539,41 +529,40 @@ class ImperativeCalcOutputScale(object):
             paddle.disable_static()

     def _is_target_layer(self, layer):
-        return isinstance(layer, self._out_scale_layer_type_list) \
+        return isinstance(layer, utils.out_scale_layers_list) \
             or 'quantized_' in layer.full_name()

-    # When inferenc model is saved, the logic in hook would not be executed
-    # in program translation, so that some parameters can not created in
-    # __init__, which would cause the model to fail to save. Therefore, the
-    # parameters creation in the hook is advanced to be exected outside the hook.
-    def _add_new_parameters(self, layer, name=None):
+    def _init_scale_params(self, layer, name=None):
+        """
+        Init the scale params for calculating output scales and save them
+        in the target layer.
+        After the users define the dygraph model, the hooks for calculating
+        output scales do not execute immediately. If the users load a
+        checkpoint at that point, the scale params have not been created yet,
+        so they cannot be loaded. Therefore, define the scale params at the
+        beginning.
+        """
+
+        def _create_param(in_layer, first_name, last_name, dtype):
+            prefix = '{}.{}'.format(first_name, last_name) \
+                if first_name else 'outscale.{}'.format(last_name)
+            attr = ParamAttr(
+                name=unique_name.generate(prefix),
+                initializer=Constant(1),
+                trainable=False)
+            param = in_layer.create_parameter(shape=[1], attr=attr, dtype=dtype)
+            return param
+
         dtype = layer._dtype if layer._dtype is not None else "float32"
         if dtype not in ["float32", "float64"]:
             return
-        scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
-        scale_name = unique_name.generate(scale_prefix)
-        scale_attr = ParamAttr(
-            name=scale_name, initializer=Constant(1), trainable=False)
-        layer._quant_out_scale = layer.create_parameter(
-            shape=[1], attr=scale_attr, dtype=dtype)
+
+        layer._quant_out_scale = _create_param(layer, name, "scale", dtype)
         layer._quant_out_scale.stop_gradient = True
-        state_prefix = "{}.state".format(name) if name else 'outscale.state'
-        state_attr = ParamAttr(
-            name=unique_name.generate(state_prefix),
-            initializer=Constant(1),
-            trainable=False)
-        layer._quant_out_state = layer.create_parameter(
-            shape=[1], attr=state_attr, dtype=dtype)
+
+        layer._quant_out_state = _create_param(layer, name, "state", dtype)
         layer._quant_out_state.stop_gradient = True
-        accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
-        accum_attr = ParamAttr(
-            name=unique_name.generate(accum_prefix),
-            initializer=Constant(1),
-            trainable=False)
-        layer._quant_out_accum = layer.create_parameter(
-            shape=[1], attr=accum_attr, dtype=dtype)
+
+        layer._quant_out_accum = _create_param(layer, name, "accum", dtype)
         layer._quant_out_accum.stop_gradient = True

     # Judge whether the op in program matches the Layer in dynamic model
@@ -598,20 +587,18 @@ class ImperativeCalcOutputScale(object):
             op_type = op_type.replace('relu', 're_lu')
         return op_type in layer_name

-    def _forward_post_hook(self, layer, input, output):
-        assert isinstance(
-            output, (core.VarBase, framework.Variable)
-        ), "Multiple outputs are not currently supported in ImperativeOutScale."
-        if output.dtype not in [
-                core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP64
-        ]:
-            return
-        if not hasattr(layer, "_out_scale"):
-            self._out_scale = quant_nn.MovingAverageAbsMaxScale(
-                layer, output.name, self._moving_rate, output.dtype)
-        scale_out = self._out_scale(output)
-        if hasattr(layer, 'layer_name'):
-            layer_name = layer.layer_name
-        else:
-            layer_name = layer.full_name()
-        self._out_scale_dict[layer_name] = scale_out
+    def _calc_output_scale_hook(self, layer, input, output):
+        """
+        Create the MovingAverageAbsMaxScale layer for the target layer if needed.
+        Execute MovingAverageAbsMaxScale layer to calculate the output scale.
+        """
+        assert isinstance(output, (core.VarBase, framework.Variable)), \
+            "Multiple outputs are not currently supported in ImperativeOutScale."
+
+        fp_types = [core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP64]
+        if output.dtype in fp_types:
+            if not hasattr(layer, "_out_scale"):
+                self._out_scale = quant_nn.MovingAverageAbsMaxScale(
+                    layer, output.name, self._moving_rate, output.dtype)
+            # TODO (jc): consider the ops that have several outputs
+            self._out_scale(output)
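For reference, the update rule behind the _quant_out_scale/_quant_out_state/_quant_out_accum parameters created above: MovingAverageAbsMaxScale keeps a decayed update counter and a decayed sum of per-batch abs-max values and reports their ratio, which save_quantized_model later writes as the op's out_threshold attribute. The snippet below is a conceptual NumPy sketch of that rule, not the Paddle implementation; the 0.9 default mirrors the moving_rate argument.

# Conceptual sketch of the moving-average abs-max output scale (not Paddle API).
import numpy as np


def update_output_scale(output, state, accum, moving_rate=0.9):
    # One hook invocation: fold the abs-max of the current batch into the
    # running statistics and return the refreshed scale.
    abs_max = float(np.max(np.abs(output)))
    state = moving_rate * state + 1          # decayed update counter
    accum = moving_rate * accum + abs_max    # decayed sum of abs-max values
    scale = accum / state                    # value later stored as out_threshold
    return scale, state, accum


# state and accum start from 1, matching the Constant(1) initializer above.
scale, state, accum = 1.0, 1.0, 1.0
for batch_output in (np.random.randn(4, 10) for _ in range(5)):
    scale, state, accum = update_output_scale(batch_output, state, accum)
print(scale)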
@@ -499,6 +499,10 @@ class QuantizedNoweightLayer(layers.Layer):

     def forward(self, input):
         quant_input = self._fake_quant_input(input)
+        # TODO (jc): support ops that have several inputs
+        if isinstance(input, list):
+            assert len(input) == 1, \
+                "The QuantizedNoweightLayer should only have one input."
         return self._layer.forward(quant_input)
@@ -12,12 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from paddle.nn import Linear, Conv2D
-from paddle.fluid.dygraph.nn import Pool2D
-from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6
-from paddle.nn.layer.activation import Tanh, Softmax, PReLU, Swish
+import paddle

-_op_real_in_out_name = {
+op_real_in_out_name = {
     "conv2d": [["Input", "Filter"], ["Output"]],
     "depthwise_conv2d": [["Input", "Filter"], ["Output"]],
     "pool2d": [["X"], ["Out"]],
@@ -33,14 +30,30 @@ _op_real_in_out_name = {
     "swish": [["X"], ["Out"]],
 }

-_quant_layers_map = {
-    'Conv2D': Conv2D,
-    'Linear': Linear,
-    'Pool2D': Pool2D,
-    'ReLU': ReLU,
-    'LeakyReLU': LeakyReLU,
-    'ReLU6': ReLU6,
-    'Softmax': Softmax,
-    'Tanh': Tanh,
-    'Swish': Swish
+supported_quant_layers_map = {
+    'Conv2D': paddle.nn.Conv2D,
+    'Linear': paddle.nn.Linear,
+    'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D,
+    'AdaptiveMaxPool2D': paddle.nn.AdaptiveMaxPool2D,
+    'AvgPool2D': paddle.nn.AvgPool2D,
+    'MaxPool2D': paddle.nn.MaxPool2D,
+    'Hardswish': paddle.nn.Hardswish,
+    'LeakyReLU': paddle.nn.LeakyReLU,
+    'PReLU': paddle.nn.PReLU,
+    'ReLU': paddle.nn.ReLU,
+    'ReLU6': paddle.nn.ReLU6,
+    'Sigmoid': paddle.nn.Sigmoid,
+    'Softmax': paddle.nn.Softmax,
+    'Swish': paddle.nn.Swish,
+    'Tanh': paddle.nn.Tanh,
+    'Hardswish': paddle.nn.Hardswish,
+    'BatchNorm': paddle.nn.BatchNorm,
+    'GroupNorm': paddle.nn.GroupNorm,
+    'LayerNorm': paddle.nn.LayerNorm,
 }
+
+out_scale_layers_list = (
+    paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.MaxPool2D,
+    paddle.nn.BatchNorm, paddle.nn.BatchNorm2D, paddle.nn.SyncBatchNorm,
+    paddle.nn.LeakyReLU, paddle.nn.PReLU, paddle.nn.ReLU, paddle.nn.ReLU6,
+    paddle.nn.Sigmoid, paddle.nn.Softmax, paddle.nn.Tanh, paddle.nn.Swish)
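As the updated docstring (list[str | layer]) and the tuple-building code in ImperativeQuantizeInputs show, quantizable_layer_type now accepts either the string keys of supported_quant_layers_map or the layer classes themselves. A small illustrative sketch; the constructor calls below are assumptions about typical use, not lines from this commit.

import paddle
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

# By name: strings are resolved through utils.supported_quant_layers_map.
qat_by_name = ImperativeQuantAware(
    quantizable_layer_type=['Conv2D', 'Linear', 'ReLU6'])

# By class: layer types are passed through unchanged; the assert in
# ImperativeQuantizeInputs rejects any string that is not a known key.
qat_by_class = ImperativeQuantAware(
    quantizable_layer_type=[paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.ReLU6])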
@@ -191,8 +191,8 @@ class TestImperativeAddQuantDequant(unittest.TestCase):
             weight_quantize_type='abs_max',
             activation_quantize_type='moving_average_abs_max',
             quantizable_layer_type=[
-                'Conv2D', 'Linear', 'ReLU', 'Pool2D', 'LeakyReLU', 'ReLU6',
-                'Tanh', 'Swish'
+                'Conv2D', 'Linear', 'ReLU', 'LeakyReLU', 'ReLU6', 'Tanh',
+                'Swish'
             ])
         with fluid.dygraph.guard():