Unverified commit 1d197f6c authored by cc, committed by GitHub

[dygraph qat] Refine calculating output scale of dygraph qat (#31710)

* Refine calculating output scale of dygraph qat, test=develop
Parent 420527f0
@@ -25,12 +25,7 @@ from paddle.fluid.executor import Executor
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn import Linear, Conv2D, Conv2DTranspose, MaxPool2D, MaxPool1D
from paddle.nn import BatchNorm1D, BatchNorm2D, BatchNorm3D, SyncBatchNorm
from paddle.fluid.dygraph.nn import BatchNorm, Pool2D
from paddle.fluid.io import load_inference_model, save_inference_model
from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6
from paddle.nn.layer.activation import Tanh, Softmax, PReLU, Swish
from paddle.fluid.log_helper import get_logger
from . import quant_nn
from .. import quantization_pass
@@ -62,14 +57,10 @@ class ImperativeQuantAware(object):
The constructor for ImperativeQuantAware.
Args:
quantizable_layer_type(list[str]): List the type of layers that
will be quantized. Default is ['Conv2D', 'Linear'].
The quantizable_op_type in QuantizationFreezePass and
ConvertToInt8Pass must be the same as this.
quantizable_layer_type(list[str | layer]): List the type of
layers that will be quantized. Default is ['Conv2D', 'Linear'].
weight_quantize_type(str): quantization type for weights,
which supports 'abs_max' now. The 'moving_average_abs_max'
usually is not used for weights, since weights are fixed
once the model is well trained.
which supports 'abs_max' and 'channel_wise_abs_max'.
activation_quantize_type(str): quantization type for activations,
which supports 'abs_max' and 'moving_average_abs_max' now.
If using 'abs_max' mode, the quantization scale will be
@@ -77,8 +68,8 @@ class ImperativeQuantAware(object):
period. If using 'moving_average_abs_max', the static
quantization scale will be calculated during training and
used in inference.
weight_bits(int): quantization bit number for weights,
whereas the bias is not quantized.
weight_bits(int): quantization bit number for weights, whereas
the bias is not quantized.
activation_bits(int): quantization bit number for activations.
moving_rate(float): the parameter for 'moving_average_abs_max'
quantization.
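For context, a minimal sketch of how the constructor documented above is typically used; the import path and the toy TinyNet model are assumptions for illustration, not part of this diff:

```python
import paddle
# Import path assumed; ImperativeQuantAware is defined in this file.
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

class TinyNet(paddle.nn.Layer):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = paddle.nn.Conv2D(3, 8, 3)
        self.fc = paddle.nn.Linear(8 * 30 * 30, 10)

    def forward(self, x):
        return self.fc(paddle.flatten(self.conv(x), 1))

model = TinyNet()
imperative_qat = ImperativeQuantAware(
    quantizable_layer_type=['Conv2D', 'Linear'],
    weight_quantize_type='channel_wise_abs_max',
    activation_quantize_type='moving_average_abs_max',
    weight_bits=8,
    activation_bits=8,
    moving_rate=0.9)
# Replace the quantizable layers with their quantized wrappers in place.
imperative_qat.quantize(model)
```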
@@ -260,8 +251,8 @@ class ImperativeQuantizeInputs(object):
super(ImperativeQuantizeInputs, self).__init__()
self._quantizable_layer_type = tuple(
utils._quant_layers_map[layer]
if layer in utils._quant_layers_map else layer
utils.supported_quant_layers_map[layer]
if layer in utils.supported_quant_layers_map else layer
for layer in quantizable_layer_type)
for layer in self._quantizable_layer_type:
assert not isinstance(layer, str), \
@@ -338,7 +329,7 @@ class ImperativeQuantizeInputs(object):
def _get_quantized_layer(self, layer):
quant_layer_name = None
for key, value in utils._quant_layers_map.items():
for key, value in utils.supported_quant_layers_map.items():
if isinstance(layer, value):
quant_layer_name = 'Quantized' + key
break
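A hedged illustration of the lookup the two hunks above switch to (the import path of `utils` is an assumption):

```python
import paddle
# Path assumed; utils.py is the second file changed in this commit.
from paddle.fluid.contrib.slim.quantization.imperative import utils

# quantizable_layer_type entries may be strings or layer classes;
# both resolve to the same layer class through supported_quant_layers_map.
mixed = ['Conv2D', paddle.nn.Linear]
resolved = tuple(
    utils.supported_quant_layers_map[t]
    if t in utils.supported_quant_layers_map else t
    for t in mixed)
assert resolved == (paddle.nn.Conv2D, paddle.nn.Linear)

# _get_quantized_layer then derives the wrapper name from the matching key.
quant_layer_name = 'Quantized' + 'Conv2D'  # -> 'QuantizedConv2D'
```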
@@ -364,10 +355,6 @@ class ImperativeCalcOutputScale(object):
"""
super(ImperativeCalcOutputScale, self).__init__()
self._moving_rate = moving_rate
self._out_scale_layer_type_list = (
BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D, LeakyReLU,
Linear, PReLU, Pool2D, MaxPool1D, MaxPool2D, ReLU, ReLU6, Sigmoid,
Softmax, SyncBatchNorm, Tanh, Swish)
self._register_hook_handle_list = []
self._out_scale_dict = collections.OrderedDict()
@@ -378,7 +365,7 @@ class ImperativeCalcOutputScale(object):
Args:
model(fluid.dygraph.Layer): The target model whose output
quantization scales would be calculated.
quantization scales would be calculated.
Returns:
None
@@ -387,10 +374,10 @@ class ImperativeCalcOutputScale(object):
"The model must be the instance of dygraph.Layer."
for _, layer in model.named_sublayers():
if self._is_target_layer(layer):
self._add_new_parameters(layer)
forward_post_hook_handle = layer.register_forward_post_hook(
self._forward_post_hook)
self._register_hook_handle_list.append(forward_post_hook_handle)
self._init_scale_params(layer)
hook_handle = layer.register_forward_post_hook(
self._calc_output_scale_hook)
self._register_hook_handle_list.append(hook_handle)
def save_quantized_model(self, layer, path, input_spec=None, **config):
"""
@@ -398,63 +385,64 @@ class ImperativeCalcOutputScale(object):
Args:
layer (Layer): The Layer to be saved.
path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
input_spec (list[InputSpec|Tensor], optional): Describes the input of the saved model's forward
method, which can be described by InputSpec or example Tensor. If None, all input variables of
the original Layer's forward method would be the inputs of the saved model. Default None.
**configs (dict, optional): Other save configuration options for compatibility. We do not
recommend using these configurations, they may be removed in the future. If not necessary,
DO NOT use them. Default None.
path (str): The path prefix to save model. The format is
``dirname/file_prefix`` or ``file_prefix``.
input_spec (list[InputSpec|Tensor], optional): Describes the input
of the saved model's forward method, which can be described by
InputSpec or example Tensor. If None, all input variables of
the original Layer's forward method would be the inputs of
the saved model. Default None.
**config (dict, optional): Other save configuration options for
compatibility. We do not recommend using these configurations,
they may be removed in the future. If not necessary, DO NOT use
them. Default None.
The following options are currently supported:
(1) output_spec (list[Tensor]): Selects the output targets of the saved model.
By default, all return variables of original Layer's forward method are kept as the
output of the saved model. If the provided ``output_spec`` list is not all output variables,
the saved model will be pruned according to the given ``output_spec`` list.
(1) output_spec (list[Tensor]): Selects the output targets of
the saved model. By default, all return variables of original
Layer's forward method are kept as the output of the saved model.
If the provided ``output_spec`` list is not all output variables,
the saved model will be pruned according to the given
``output_spec`` list.
Returns:
None
"""
assert isinstance(
layer, dygraph.Layer), "model must be the instance of dygraph.Layer"
self._layer = layer
is_dynamic_mode = False
assert isinstance(layer, dygraph.Layer), \
"The model must be the instance of dygraph.Layer."
# remove handles and collect output scales
with dygraph.guard():
self._layer.eval()
if self._register_hook_handle_list is not None:
for handle in self._register_hook_handle_list:
handle.remove()
if self._out_scale_dict:
for key in self._out_scale_dict:
self._out_scale_dict[key] = float(self._out_scale_dict[key]
.numpy())
else:
for _, sub_layer in self._layer.named_sublayers():
if self._is_target_layer(sub_layer):
layer.eval()
for handle in self._register_hook_handle_list:
handle.remove()
for _, sub_layer in layer.named_sublayers():
if self._is_target_layer(sub_layer):
if hasattr(sub_layer, "layer_name"):
layer_name = sub_layer.layer_name
else:
layer_name = sub_layer.full_name()
if hasattr(sub_layer, "layer_name"):
layer_name = sub_layer.layer_name
if hasattr(sub_layer, "_quant_out_scale"):
self._out_scale_dict[layer_name] = float(
sub_layer._quant_out_scale)
if hasattr(sub_layer, "_quant_out_scale"):
self._out_scale_dict[layer_name] = float(
sub_layer._quant_out_scale)
# save the quantized model that doesn't have output scales
paddle.jit.save(layer=layer, path=path, input_spec=input_spec, **config)
# load static model
is_dynamic_mode = False
if paddle.in_dynamic_mode():
is_dynamic_mode = True
paddle.enable_static()
paddle.jit.save(layer=layer, path=path, input_spec=input_spec, **config)
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
place = core.CUDAPlace(0) if core.is_compiled_with_cuda() \
else core.CPUPlace()
exe = Executor(place)
file_prefix = os.path.basename(path)
dirname = os.path.dirname(path)
model_filename = file_prefix + INFER_MODEL_SUFFIX
params_filename = file_prefix + INFER_PARAMS_SUFFIX
basename = os.path.basename(path)
model_filename = basename + INFER_MODEL_SUFFIX
params_filename = basename + INFER_PARAMS_SUFFIX
[inference_program, feed_target_names, fetch_targets] = (
load_inference_model(
dirname=dirname,
@@ -462,14 +450,15 @@ class ImperativeCalcOutputScale(object):
model_filename=model_filename,
params_filename=params_filename))
# set output scales to the static model
check_behind_op = False
op_count = 0
ops_list = [key for key, _ in self._out_scale_dict.items()]
if len(ops_list) == 0:
warnings.warn(
"Warning: No Layer of the model while to be saved contains the out_threshold attribute, "
"so the generated inference model would not contain the out_threshold."
)
"Warning: No Layer of the model while to be saved contains "
"the out_threshold attribute, so the generated inference "
"model would not contain the out_threshold.")
else:
# Because the Layer in dygraph may correspond to multiple ops
# in static program after being saved. To ensure correctness,
@@ -481,11 +470,12 @@ class ImperativeCalcOutputScale(object):
forward_op = None
for block in inference_program.blocks:
for op in block.ops:
if op.type in utils._op_real_in_out_name:
if op.type in utils.op_real_in_out_name:
if op_count > len(ops_list):
warnings.warn(
"The number of Layer which has out_threshold attribute should be bigger than the op in inference model"
)
"The number of Layer which has "
"out_threshold attribute should be bigger than "
"the op in inference model")
break
if check_behind_op:
check_behind_op = False
@@ -525,7 +515,7 @@ class ImperativeCalcOutputScale(object):
self._out_scale_dict[ops_list[op_count]])
op_count += 1
# Save the processed program.
# save the final quantized model that has output scales
save_inference_model(
dirname=dirname,
feeded_var_names=feed_target_names,
@@ -539,41 +529,40 @@ class ImperativeCalcOutputScale(object):
paddle.disable_static()
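Putting the method above together: collect the per-layer output scales, save the model with paddle.jit.save, reload it as a static program, write each scale into the matching op as an out_threshold attribute, and re-save. A hedged end-to-end sketch (the import path and the toy Sequential model are assumptions; apply() and save_quantized_model() come from this class):

```python
import paddle
# Import path assumed; ImperativeCalcOutputScale is defined in this file.
from paddle.fluid.contrib.slim.quantization import ImperativeCalcOutputScale

model = paddle.nn.Sequential(paddle.nn.Conv2D(3, 8, 3), paddle.nn.ReLU())

calc_scale = ImperativeCalcOutputScale(moving_rate=0.9)
calc_scale.apply(model)  # registers forward post hooks on the target layers

# Run a few forward passes (e.g. training or calibration) so that the hooks
# can update the moving-average output scales.
for _ in range(8):
    model(paddle.rand([1, 3, 32, 32]))

calc_scale.save_quantized_model(
    layer=model,
    path="./quant_model/sequential",
    input_spec=[paddle.static.InputSpec(
        shape=[None, 3, 32, 32], dtype='float32')])
```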
def _is_target_layer(self, layer):
return isinstance(layer, self._out_scale_layer_type_list) \
return isinstance(layer, utils.out_scale_layers_list) \
or 'quantized_' in layer.full_name()
# When the inference model is saved, the logic in the hook would not be
# executed during program translation, so some parameters could not be
# created in __init__, which would cause the model to fail to save. Therefore,
# the parameter creation in the hook is moved ahead, outside the hook.
def _add_new_parameters(self, layer, name=None):
def _init_scale_params(self, layer, name=None):
"""
Init the scale params for calculating output scales and save them in the
target layer.
After the users define the dygraph model, the hooks for calculating output
scales do not execute immediately. If the users load a checkpoint at this
point, the scale params have not been created yet, so they cannot be loaded.
Therefore, define the scale params at the beginning.
"""
def _create_param(in_layer, first_name, last_name, dtype):
prefix = '{}.{}'.format(first_name, last_name) \
if first_name else 'outscale.{}'.format(last_name)
attr = ParamAttr(
name=unique_name.generate(prefix),
initializer=Constant(1),
trainable=False)
param = in_layer.create_parameter(shape=[1], attr=attr, dtype=dtype)
return param
dtype = layer._dtype if layer._dtype is not None else "float32"
if dtype not in ["float32", "float64"]:
return
scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
scale_name = unique_name.generate(scale_prefix)
scale_attr = ParamAttr(
name=scale_name, initializer=Constant(1), trainable=False)
layer._quant_out_scale = layer.create_parameter(
shape=[1], attr=scale_attr, dtype=dtype)
layer._quant_out_scale = _create_param(layer, name, "scale", dtype)
layer._quant_out_scale.stop_gradient = True
state_prefix = "{}.state".format(name) if name else 'outscale.state'
state_attr = ParamAttr(
name=unique_name.generate(state_prefix),
initializer=Constant(1),
trainable=False)
layer._quant_out_state = layer.create_parameter(
shape=[1], attr=state_attr, dtype=dtype)
layer._quant_out_state = _create_param(layer, name, "state", dtype)
layer._quant_out_state.stop_gradient = True
accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
accum_attr = ParamAttr(
name=unique_name.generate(accum_prefix),
initializer=Constant(1),
trainable=False)
layer._quant_out_accum = layer.create_parameter(
shape=[1], attr=accum_attr, dtype=dtype)
layer._quant_out_accum = _create_param(layer, name, "accum", dtype)
layer._quant_out_accum.stop_gradient = True
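A standalone, hedged illustration of the pattern `_create_param` implements above (a constant-initialized, non-trainable scalar parameter attached to an existing layer); the Conv2D layer here is only an example:

```python
import paddle
from paddle.fluid import unique_name
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant

layer = paddle.nn.Conv2D(3, 8, 3)
attr = ParamAttr(
    name=unique_name.generate('outscale.scale'),
    initializer=Constant(1),
    trainable=False)
# Attach a [1]-shaped parameter to the layer so it is saved and loaded
# together with the layer's state.
layer._quant_out_scale = layer.create_parameter(
    shape=[1], attr=attr, dtype='float32')
layer._quant_out_scale.stop_gradient = True  # never updated by backprop
```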
# Judge whether the op in program matches the Layer in dynamic model
......@@ -598,20 +587,18 @@ class ImperativeCalcOutputScale(object):
op_type = op_type.replace('relu', 're_lu')
return op_type in layer_name
def _forward_post_hook(self, layer, input, output):
assert isinstance(
output, (core.VarBase, framework.Variable)
), "Multiple outputs are not currently supported in ImperativeOutScale."
if output.dtype not in [
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP64
]:
return
if not hasattr(layer, "_out_scale"):
self._out_scale = quant_nn.MovingAverageAbsMaxScale(
layer, output.name, self._moving_rate, output.dtype)
scale_out = self._out_scale(output)
if hasattr(layer, 'layer_name'):
layer_name = layer.layer_name
else:
layer_name = layer.full_name()
self._out_scale_dict[layer_name] = scale_out
def _calc_output_scale_hook(self, layer, input, output):
"""
Create the MovingAverageAbsMaxScale layer for the target layer if needed,
then execute it to calculate the output scale.
"""
assert isinstance(output, (core.VarBase, framework.Variable)), \
"Multiple outputs are not currently supported in ImperativeOutScale."
fp_types = [core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP64]
if output.dtype in fp_types:
if not hasattr(layer, "_out_scale"):
self._out_scale = quant_nn.MovingAverageAbsMaxScale(
layer, output.name, self._moving_rate, output.dtype)
# TODO (jc): consider the ops that have several outputs
self._out_scale(output)
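For reference, a self-contained sketch of how a forward post-hook such as `_calc_output_scale_hook` is attached and triggered; the printing hook below is only an illustration, not the hook used in this file:

```python
import paddle

def report_abs_max_hook(layer, inputs, output):
    # Paddle calls this after every forward pass of `layer`.
    print(layer.full_name(), float(paddle.abs(output).max()))

conv = paddle.nn.Conv2D(3, 8, 3)
handle = conv.register_forward_post_hook(report_abs_max_hook)
conv(paddle.rand([1, 3, 32, 32]))  # triggers the hook
handle.remove()                    # same cleanup as in save_quantized_model
```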
@@ -499,6 +499,10 @@ class QuantizedNoweightLayer(layers.Layer):
def forward(self, input):
quant_input = self._fake_quant_input(input)
# TODO (jc): support ops that have several inputs
if isinstance(input, list):
assert len(input) == 1, \
"The QuantizedNoweightLayer should only have one input."
return self._layer.forward(quant_input)
@@ -12,12 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.nn import Linear, Conv2D
from paddle.fluid.dygraph.nn import Pool2D
from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6
from paddle.nn.layer.activation import Tanh, Softmax, PReLU, Swish
import paddle
_op_real_in_out_name = {
op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input", "Filter"], ["Output"]],
"pool2d": [["X"], ["Out"]],
@@ -33,14 +30,30 @@ _op_real_in_out_name = {
"swish": [["X"], ["Out"]],
}
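The mapping renamed above records, for each op type, its real input slot names and real output slot names. A small hedged illustration of its layout (using this module's own dict):

```python
# Assuming op_real_in_out_name from this module is in scope.
in_slots, out_slots = op_real_in_out_name["conv2d"]
assert in_slots == ["Input", "Filter"] and out_slots == ["Output"]
# save_quantized_model uses these slot names to locate the real input/output
# variables of each static op when attaching the collected out_threshold values.
```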
_quant_layers_map = {
'Conv2D': Conv2D,
'Linear': Linear,
'Pool2D': Pool2D,
'ReLU': ReLU,
'LeakyReLU': LeakyReLU,
'ReLU6': ReLU6,
'Softmax': Softmax,
'Tanh': Tanh,
'Swish': Swish
supported_quant_layers_map = {
'Conv2D': paddle.nn.Conv2D,
'Linear': paddle.nn.Linear,
'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D,
'AdaptiveMaxPool2D': paddle.nn.AdaptiveMaxPool2D,
'AvgPool2D': paddle.nn.AvgPool2D,
'MaxPool2D': paddle.nn.MaxPool2D,
'Hardswish': paddle.nn.Hardswish,
'LeakyReLU': paddle.nn.LeakyReLU,
'PReLU': paddle.nn.PReLU,
'ReLU': paddle.nn.ReLU,
'ReLU6': paddle.nn.ReLU6,
'Sigmoid': paddle.nn.Sigmoid,
'Softmax': paddle.nn.Softmax,
'Swish': paddle.nn.Swish,
'Tanh': paddle.nn.Tanh,
'BatchNorm': paddle.nn.BatchNorm,
'GroupNorm': paddle.nn.GroupNorm,
'LayerNorm': paddle.nn.LayerNorm,
}
out_scale_layers_list = (
paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.MaxPool2D,
paddle.nn.BatchNorm, paddle.nn.BatchNorm2D, paddle.nn.SyncBatchNorm,
paddle.nn.LeakyReLU, paddle.nn.PReLU, paddle.nn.ReLU, paddle.nn.ReLU6,
paddle.nn.Sigmoid, paddle.nn.Softmax, paddle.nn.Tanh, paddle.nn.Swish)
@@ -191,8 +191,8 @@ class TestImperativeAddQuantDequant(unittest.TestCase):
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max',
quantizable_layer_type=[
'Conv2D', 'Linear', 'ReLU', 'Pool2D', 'LeakyReLU', 'ReLU6',
'Tanh', 'Swish'
'Conv2D', 'Linear', 'ReLU', 'LeakyReLU', 'ReLU6', 'Tanh',
'Swish'
])
with fluid.dygraph.guard():