Unverified commit 00c85a74, authored by cc and committed by GitHub

[Dygraph QAT] Save all scales to target ops and Move quant layers to paddle.nn.quant (#33871)

* Save all scales to target ops
* Move quant layers to paddle.nn.quant
Parent commit: ea1a0d45
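Note for readers (not part of the commit): the second change means the dygraph fake-quant layers are now importable from paddle.nn.quant.quant_layers instead of paddle.fluid.contrib.slim.quantization.imperative.quant_nn. A minimal sketch of the new import path, assuming a PaddlePaddle build that already contains this commit:

    import paddle
    import paddle.nn.quant.quant_layers as quant_layers

    # Wrap an ordinary float layer with its fake-quantized counterpart,
    # the same way ImperativeQuantizeInputs does in qat.py below.
    conv = paddle.nn.Conv2D(in_channels=3, out_channels=8, kernel_size=3)
    qconv = quant_layers.QuantizedConv2D(conv)

    # QuantStub is exported as an alias of MovingAverageAbsMaxScale
    # (see the quant_layers.py hunk at the end of this diff).
    stub = quant_layers.QuantStub()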
@@ -14,9 +14,6 @@
from __future__ import print_function
-from . import quant_nn
-from .quant_nn import *
from . import qat
from .qat import *
@@ -33,7 +30,6 @@ from . import ptq_registry
from .ptq_registry import *
__all__ = []
-__all__ += quant_nn.__all__
__all__ += qat.__all__
__all__ += ptq.__all__
__all__ += ptq_config.__all__
......
@@ -20,6 +20,7 @@ import os
import warnings
import paddle
+import paddle.nn.quant.quant_layers as quant_layers
from paddle.fluid import dygraph, core, framework, unique_name
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.param_attr import ParamAttr
@@ -28,7 +29,6 @@ from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.fluid.io import load_inference_model, save_inference_model
from paddle.fluid.log_helper import get_logger
from .. import quantization_pass
-from . import quant_nn
from . import utils
__all__ = ['ImperativeQuantAware']
@@ -39,7 +39,7 @@ _logger = get_logger(
class ImperativeQuantAware(object):
"""
-Applying quantization aware training (QAT) to dgraph model.
+Applying quantization aware training (QAT) to the dgraph model.
"""
def __init__(self,
@@ -329,12 +329,12 @@ class ImperativeQuantizeInputs(object):
"The layer %s is unsupported to be quantized." \
% layer.full_name()
-return quant_nn.__dict__[quant_layer_name](layer, **self._kwargs)
+return quant_layers.__dict__[quant_layer_name](layer, **self._kwargs)
class ImperativeQuantizeOutputs(object):
"""
-Calculate the output scales for some layers.
+Calculate the output scales for target layers.
"""
def __init__(self, moving_rate=0.9):
@@ -371,11 +371,11 @@ class ImperativeQuantizeOutputs(object):
utils.find_parent_layer_and_sub_name(model, cur_name)
if isinstance(cur_layer, tuple(utils.fake_quant_output_layers)):
-cur_quant_layer = quant_nn.FakeQuantMAOutputScaleLayer(
+cur_quant_layer = quant_layers.FakeQuantMAOutputScaleLayer(
cur_layer, self._moving_rate)
else:
-cur_quant_layer = quant_nn.MAOutputScaleLayer(cur_layer,
-self._moving_rate)
+cur_quant_layer = quant_layers.MAOutputScaleLayer(
+cur_layer, self._moving_rate)
setattr(parent_layer, sub_name, cur_quant_layer)
@@ -433,7 +433,7 @@ class ImperativeQuantizeOutputs(object):
model_filename=model_filename,
params_filename=params_filename))
-self._save_output_scale(infer_program, scope)
+self._gather_scales(infer_program, scope)
self._set_skip_quant_attr(infer_program)
@@ -455,37 +455,80 @@
"""
flag = False
if isinstance(layer, dygraph.Layer):
-# exclude fake_quant ops in quant_nn file
+# exclude fake_quant ops in quant_layers file
if utils.is_leaf_layer(layer) and \
not isinstance(layer, tuple(utils.fake_quant_leaf_layers)):
flag = True
-# consider QuantizedConv2D and QuantizedLinear ops
if isinstance(layer, tuple(utils.fake_quant_wrap_layers)):
flag = True
if isinstance(layer, paddle.nn.quant.FloatFunctionalLayer):
flag = True
return flag
-def _save_output_scale(self, program, scope):
+def _gather_scales(self, program, scope):
"""
-Save all output scales to the corresponding ops in static
-inference program and delete 'moving_average_abs_max_scale' ops.
+Get all scales from fake ops, save them into the corresponding ops
+and delete all moving_average_abs_max_scale ops.
"""
+def _gather_input_scale():
+target_ops = []
+skip_ops = utils.fake_quantize_dequantize_op_types + \
+["moving_average_abs_max_scale"]
+for block in program.blocks:
+for op in block.ops:
+if op.type not in skip_ops:
+target_ops.append(op)
+for op in target_ops:
+for in_var_name in utils._get_op_input_var_names(op):
+previous_op = utils.find_previous_op(op.block, in_var_name)
+if previous_op is not None and \
+("quantize_dequantize" in previous_op.type or \
+previous_op.type == "moving_average_abs_max_scale"):
+scale_name = previous_op.output('OutScale')[0]
+in_scale = utils.load_variable_data(scope, scale_name)
+in_scale = utils.fp_numpy_to_naive(in_scale)
+argname, index = utils._get_input_name_index(
+op, in_var_name)
+op._set_attr(argname + str(index) + "_threshold",
+in_scale)
+def _gather_output_scale():
+target_ops = []
for block in program.blocks:
for op in block.ops:
if op.type == "moving_average_abs_max_scale":
+target_ops.append(op)
+for op in target_ops:
in_var_name = op.input('X')[0]
out_var_name = op.output('Out')[0]
-out_scale_name = op.output('OutScale')[0]
+block = op.block
+previous_op = utils.find_previous_op(block, in_var_name)
+next_ops = utils.find_next_ops(block, out_var_name)
+out_scale_name = op.output('OutScale')[0]
out_scale = utils.load_variable_data(scope, out_scale_name)
-previous_op = utils.find_previous_op(block, in_var_name)
-previous_op._set_attr("out_threshold", float(out_scale))
+out_scale = utils.fp_numpy_to_naive(out_scale)
+if previous_op.type != "feed":
+argname, index = utils._get_output_name_index(previous_op,
+in_var_name)
+previous_op._set_attr(argname + str(index) + "_threshold",
+out_scale)
+previous_op._set_attr("out_threshold", out_scale)
-next_ops = utils.find_next_ops(block, out_var_name)
for next_op in next_ops:
next_op._rename_input(out_var_name, in_var_name)
+_gather_input_scale()
+_gather_output_scale()
def _set_skip_quant_attr(self, program):
"""
Label the skip quantized ops.
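Aside (not part of the diff): the attribute names written by _gather_scales above follow the pattern "<argument name><index>_threshold", plus "out_threshold" on the op that produces the variable feeding a moving_average_abs_max_scale op. A tiny sketch of the name construction only, with made-up values:

    # Same string construction as op._set_attr(argname + str(index) + "_threshold", ...) above.
    argname, index, scale = "X", 0, 0.127   # hypothetical op argument name and scale
    print(argname + str(index) + "_threshold", scale)   # X0_threshold 0.127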
......
@@ -16,8 +16,12 @@ import math
import numpy as np
import paddle
+import paddle.nn.quant.quant_layers as quant_layers
-from . import quant_nn
+from ..quantization_pass import _get_op_input_var_names
+from ..quantization_pass import _get_op_output_var_names
+from ..quantization_pass import _get_output_name_index
+from ..quantization_pass import _get_input_name_index
layer_name_map = {
'Conv2D': paddle.nn.Conv2D,
@@ -54,13 +58,15 @@ fake_quant_output_layers = [
]
fake_quant_leaf_layers = [
-quant_nn.FakeQuantAbsMax,
-quant_nn.FakeQuantChannelWiseAbsMax,
-quant_nn.FakeQuantMovingAverageAbsMax,
-quant_nn.MovingAverageAbsMaxScale,
+quant_layers.FakeQuantAbsMax,
+quant_layers.FakeQuantChannelWiseAbsMax,
+quant_layers.FakeQuantMovingAverageAbsMax,
+quant_layers.MovingAverageAbsMaxScale,
]
-fake_quant_wrap_layers = [quant_nn.QuantizedConv2D, quant_nn.QuantizedLinear]
+fake_quant_wrap_layers = [
+quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear
+]
# The weight format of these layers is Cin * Cout * H * W
spec_channel_axis_layers = [paddle.nn.Conv2D, paddle.nn.Conv2DTranspose]
@@ -94,6 +100,7 @@ def find_previous_op(block, var_name):
for op in block.ops:
if var_name in op.output_arg_names:
return op
+return None
def find_next_ops(block, var_name):
@@ -244,3 +251,10 @@ def cal_kl_scaling_factor(hist, abs_max, bits):
break
min_kl_index = starting_iter
return (min_kl_index + 0.5) * bin_width
+def fp_numpy_to_naive(x_np):
+if x_np.size == 1:
+return float(x_np)
+else:
+return x_np.tolist()
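Aside (not part of the diff): fp_numpy_to_naive simply turns a per-tensor scale (a size-1 array) into a Python float and a per-channel scale into a plain list, so it can be stored as an op attribute. A small self-contained usage sketch:

    import numpy as np

    def fp_numpy_to_naive(x_np):
        # Copied from the helper added in the hunk above.
        if x_np.size == 1:
            return float(x_np)
        else:
            return x_np.tolist()

    print(fp_numpy_to_naive(np.array([0.5])))        # 0.5
    print(fp_numpy_to_naive(np.array([0.5, 0.25])))  # [0.5, 0.25]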
@@ -141,12 +141,21 @@ _channelwise_quant_axis1_ops = ['conv2d_transpose', 'mul']
def _get_op_input_var_names(op):
-""" """
+"""
+Get the input var names of the op.
+Args:
+op(IrNode, Operator): the input op.
+Returns:
+input_var_names or None.
+"""
assert isinstance(op, (IrNode, Operator)), \
"The input op should be IrNode or Operator."
var_names = []
op_name = op.name() if isinstance(op, IrNode) \
else op.type
+if op_name not in _op_real_in_out_name:
+return []
name_list = _op_real_in_out_name[op_name][0]
for name in name_list:
var_name = op.input(name)
@@ -163,6 +172,9 @@ def _get_input_name_index(op, input_var_name):
"The input op should be IrNode or Operator."
op_name = op.name() if isinstance(op, IrNode) \
else op.type
+if op_name not in _op_real_in_out_name:
+return None
res = None
for argname in _op_real_in_out_name[op_name][0]:
var_names = op.input(argname)
@@ -179,6 +191,9 @@ def _get_op_output_var_names(op):
var_names = []
op_name = op.name() if isinstance(op, IrNode) \
else op.type
+if op_name not in _op_real_in_out_name:
+return []
name_list = _op_real_in_out_name[op_name][1]
for name in name_list:
var_name = op.output(name)
@@ -195,6 +210,9 @@ def _get_output_name_index(op, output_var_name):
"The input op should be IrNode or Operator."
op_name = op.name() if isinstance(op, IrNode) \
else op.type
+if op_name not in _op_real_in_out_name:
+return None
name_list = _op_real_in_out_name[op_name][1]
res = None
for name in name_list:
......
@@ -31,7 +31,7 @@ from paddle.fluid.dygraph.container import Sequential
from paddle.nn import Linear, Conv2D, Softmax
from paddle.fluid.log_helper import get_logger
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
-from paddle.fluid.contrib.slim.quantization.imperative.quant_nn import QuantizedConv2D
+from paddle.nn.quant.quant_layers import QuantizedConv2D
from imperative_test_utils import fix_model_dict, ImperativeLenet
......
@@ -20,7 +20,7 @@ import paddle
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
-from paddle.fluid.contrib.slim.quantization.imperative import quant_nn
+import paddle.nn.quant.quant_layers as quant_layers
paddle.enable_static()
@@ -45,7 +45,7 @@ class TestMovingAverageAbsMaxScaleOp(unittest.TestCase):
name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
fc_tmp = fluid.layers.fc(image, size=10, act='softmax')
-out_scale = quant_nn.MovingAverageAbsMaxScale(
+out_scale = quant_layers.MovingAverageAbsMaxScale(
name=fc_tmp.name, dtype=fc_tmp.dtype)
fc_tmp_1 = out_scale(fc_tmp)
cross_entropy = fluid.layers.softmax_with_cross_entropy(fc_tmp,
......
@@ -21,5 +21,6 @@ from .functional_layers import reshape  # noqa: F401
from .functional_layers import transpose  # noqa: F401
from .functional_layers import concat  # noqa: F401
from .functional_layers import flatten  # noqa: F401
+from .quant_layers import QuantStub  # noqa: F401
__all__ = []
@@ -26,21 +26,103 @@ import logging
from paddle.fluid.log_helper import get_logger
__all__ = [
-'FakeQuantMovingAverageAbsMax',
'FakeQuantAbsMax',
+'FakeQuantMovingAverageAbsMax',
'FakeQuantChannelWiseAbsMax',
'QuantizedConv2D',
'QuantizedLinear',
-'QuantizedNoweightLayer',
'MovingAverageAbsMaxScale',
'MAOutputScaleLayer',
'FakeQuantMAOutputScaleLayer',
+'QuantStub',
]
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
+class FakeQuantAbsMax(layers.Layer):
+r"""
+FakeQuantAbsMax layer does the abs_max quant and then dequant.
+Its computational formula is described as below:
+:math:`scale = max(abs(X))`
+:math:`range = 2^{bit\_length - 1} - 1`
+:math:`Out = round(X / scale * range) * scale / range`
+"""
+def __init__(self,
+name=None,
+quant_bits=8,
+dtype='float32',
+quant_on_weight=False):
+super(FakeQuantAbsMax, self).__init__()
+self._quant_bits = quant_bits
+self._name = name
+scale_prefix = "{}.scale".format(
+name) if name else 'quant_dequant.scale'
+self._scale_name = unique_name.generate(scale_prefix)
+if quant_on_weight:
+scale_attr = ParamAttr(
+name=self._scale_name,
+initializer=Constant(0.0),
+trainable=False)
+self._scale = self.create_parameter(
+shape=[1], attr=scale_attr, dtype=self._dtype)
+self._scale.stop_gradient = True
+else:
+self._scale = None
+def forward(self, input):
+if in_dygraph_mode():
+attrs = ('bit_length', self._quant_bits)
+quant_out = _varbase_creator(
+type=input.type,
+name="{}.quantized.dequantized".format(input.name),
+shape=input.shape,
+dtype=input.dtype,
+persistable=False)
+out_scale = self._scale
+if not out_scale:
+out_scale = _varbase_creator(
+type=core.VarDesc.VarType.LOD_TENSOR,
+name=self._scale_name,
+shape=[1],
+dtype=self._dtype,
+persistable=False)
+out_scale.stop_gradient = True
+out, _, = core.ops.fake_quantize_dequantize_abs_max(
+input, quant_out, out_scale, *attrs)
+return out
+check_variable_and_dtype(input, 'input', ['float32'], "FakeQuantAbsMax")
+attrs = {'bit_length': self._quant_bits}
+inputs = {"X": [input]}
+quant_out = self._helper.create_variable(
+name="{}.quantized.dequantized".format(input.name),
+dtype=input.dtype,
+type=core.VarDesc.VarType.LOD_TENSOR,
+persistable=False,
+stop_gradient=False)
+out_scale = self._scale
+if not out_scale:
+out_scale = self._helper.create_variable(
+name=self._scale_name,
+dtype=self._dtype,
+type=core.VarDesc.VarType.LOD_TENSOR,
+persistable=False,
+stop_gradient=True)
+outputs = {"Out": [quant_out], "OutScale": [out_scale]}
+self._helper.append_op(
+type="fake_quantize_dequantize_abs_max",
+inputs=inputs,
+outputs=outputs,
+attrs=attrs)
+return quant_out
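Aside (not part of the diff): the FakeQuantAbsMax docstring formula above can be checked numerically. A minimal NumPy sketch of one 8-bit quant-dequant pass, with made-up input values:

    import numpy as np

    # scale = max(|X|), range = 2**(bit_length - 1) - 1,
    # out = round(X / scale * range) * scale / range
    x = np.array([0.3, -1.2, 0.8], dtype=np.float32)
    scale = np.abs(x).max()          # 1.2
    qrange = 2 ** (8 - 1) - 1        # 127
    out = np.round(x / scale * qrange) * scale / qrange
    print(out)                       # approx. [ 0.3024 -1.2  0.8031]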
class FakeQuantMovingAverageAbsMax(layers.Layer):
r"""
FakeQuantMovingAverageAbsMax layer does the moving_average_abs_max quant and then dequant.
@@ -64,7 +146,7 @@ class FakeQuantMovingAverageAbsMax(layers.Layer):
name) if name else 'quant_dequant.scale'
scale_attr = ParamAttr(
name=unique_name.generate(scale_prefix),
-initializer=Constant(0.001),
+initializer=Constant(0.),
trainable=False)
self._scale = self.create_parameter(
shape=[1], attr=scale_attr, dtype=dtype)
@@ -74,7 +156,7 @@ class FakeQuantMovingAverageAbsMax(layers.Layer):
name) if name else 'quant_dequant.state'
state_attr = ParamAttr(
name=unique_name.generate(state_prefix),
-initializer=Constant(1),
+initializer=Constant(0),
trainable=False)
self._state = self.create_parameter(
shape=[1], attr=state_attr, dtype=dtype)
@@ -84,7 +166,7 @@ class FakeQuantMovingAverageAbsMax(layers.Layer):
name) if name else 'quant_dequant.accum'
accum_attr = ParamAttr(
name=unique_name.generate(accum_prefix),
-initializer=Constant(1),
+initializer=Constant(0),
trainable=False)
self._accum = self.create_parameter(
shape=[1], attr=accum_attr, dtype=dtype)
@@ -139,24 +221,21 @@ class FakeQuantMovingAverageAbsMax(layers.Layer):
return quant_out
-class FakeQuantAbsMax(layers.Layer):
-r"""
-FakeQuantAbsMax layer does the abs_max quant and then dequant.
-Its computational formula is described as below:
-:math:`scale = max(abs(X))`
-:math:`range = 2^{bit\_length - 1} - 1`
-:math:`Out = round(X / scale * range) * scale / range`
-"""
+class FakeQuantChannelWiseAbsMax(layers.Layer):
def __init__(self,
name=None,
+channel_num=None,
quant_bits=8,
+quant_axis=0,
dtype='float32',
quant_on_weight=False):
-super(FakeQuantAbsMax, self).__init__()
+assert quant_on_weight == True, "Channel_wise only can be used on weight quantization."
+super(FakeQuantChannelWiseAbsMax, self).__init__()
self._quant_bits = quant_bits
+self._quant_axis = quant_axis
+self._dtype = dtype
self._name = name
+self._channel_num = channel_num
scale_prefix = "{}.scale".format(
name) if name else 'quant_dequant.scale'
self._scale_name = unique_name.generate(scale_prefix)
@@ -166,35 +245,39 @@ class FakeQuantAbsMax(layers.Layer):
initializer=Constant(0.0),
trainable=False)
self._scale = self.create_parameter(
-shape=[1], attr=scale_attr, dtype=self._dtype)
+shape=[self._channel_num], attr=scale_attr, dtype=self._dtype)
self._scale.stop_gradient = True
else:
self._scale = None
def forward(self, input):
if in_dygraph_mode():
-attrs = ('bit_length', self._quant_bits)
+attrs = ('bit_length', self._quant_bits, 'quant_axis',
+self._quant_axis)
quant_out = _varbase_creator(
type=input.type,
name="{}.quantized.dequantized".format(input.name),
shape=input.shape,
dtype=input.dtype,
persistable=False)
out_scale = self._scale
-if not out_scale:
+if out_scale is None:
out_scale = _varbase_creator(
type=core.VarDesc.VarType.LOD_TENSOR,
name=self._scale_name,
-shape=[1],
+shape=[self._channel_num],
dtype=self._dtype,
persistable=False)
out_scale.stop_gradient = True
-out, _, = core.ops.fake_quantize_dequantize_abs_max(
+out, _, = core.ops.fake_channel_wise_quantize_dequantize_abs_max(
input, quant_out, out_scale, *attrs)
return out
-check_variable_and_dtype(input, 'input', ['float32'], "FakeQuantAbsMax")
-attrs = {'bit_length': self._quant_bits}
+check_variable_and_dtype(input, 'input', ['float32'],
+"FakeQuantChannelWiseAbsMax")
+attrs = {'bit_length': self._quant_bits, 'quant_axis': self._quant_axis}
inputs = {"X": [input]}
quant_out = self._helper.create_variable(
name="{}.quantized.dequantized".format(input.name),
@@ -213,7 +296,7 @@ class FakeQuantAbsMax(layers.Layer):
outputs = {"Out": [quant_out], "OutScale": [out_scale]}
self._helper.append_op(
-type="fake_quantize_dequantize_abs_max",
+type="fake_channel_wise_quantize_dequantize_abs_max",
inputs=inputs,
outputs=outputs,
attrs=attrs)
@@ -221,82 +304,83 @@ class FakeQuantAbsMax(layers.Layer):
return quant_out
-class FakeQuantChannelWiseAbsMax(layers.Layer):
-def __init__(self,
-name=None,
-channel_num=None,
-quant_bits=8,
-quant_axis=0,
-dtype='float32',
-quant_on_weight=False):
-assert quant_on_weight == True, "Channel_wise only can be used on weight quantization."
-super(FakeQuantChannelWiseAbsMax, self).__init__()
-self._quant_bits = quant_bits
-self._quant_axis = quant_axis
-self._dtype = dtype
-self._name = name
-self._channel_num = channel_num
-scale_prefix = "{}.scale".format(
-name) if name else 'quant_dequant.scale'
-self._scale_name = unique_name.generate(scale_prefix)
-if quant_on_weight:
+class MovingAverageAbsMaxScale(layers.Layer):
+def __init__(self, name=None, moving_rate=0.9, dtype='float32'):
+r"""
+MovingAverageMaxScale layer is used to calculating the output quantization
+scale of Layer. Its computational formula is described as below:
+:math:`scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)`
+:math:`Out = X`
+"""
+super(MovingAverageAbsMaxScale, self).__init__()
+self._moving_rate = moving_rate
+scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
+scale_name = unique_name.generate(scale_prefix)
scale_attr = ParamAttr(
-name=self._scale_name,
-initializer=Constant(0.0),
-trainable=False)
+name=scale_name, initializer=Constant(0), trainable=False)
self._scale = self.create_parameter(
-shape=[self._channel_num], attr=scale_attr, dtype=self._dtype)
+shape=[1], attr=scale_attr, dtype=dtype)
self._scale.stop_gradient = True
-else:
-self._scale = None
+state_prefix = "{}.state".format(name) if name else 'outscale.state'
+state_attr = ParamAttr(
+name=unique_name.generate(state_prefix),
+initializer=Constant(0),
+trainable=False)
+self._state = self.create_parameter(
+shape=[1], attr=state_attr, dtype=dtype)
+self._state.stop_gradient = True
+accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
+accum_attr = ParamAttr(
+name=unique_name.generate(accum_prefix),
+initializer=Constant(0),
+trainable=False)
+self._accum = self.create_parameter(
+shape=[1], attr=accum_attr, dtype=dtype)
+self._accum.stop_gradient = True
def forward(self, input):
if in_dygraph_mode():
-attrs = ('bit_length', self._quant_bits, 'quant_axis',
-self._quant_axis)
+attrs = ('moving_rate', self._moving_rate, 'is_test',
+not self.training)
+state = self._state if self.training else None
+accum = self._accum if self.training else None
quant_out = _varbase_creator(
type=input.type,
-name="{}.quantized.dequantized".format(input.name),
+name="{}.tmp".format(input.name),
shape=input.shape,
dtype=input.dtype,
persistable=False)
-out_scale = self._scale
-if out_scale is None:
-out_scale = _varbase_creator(
-type=core.VarDesc.VarType.LOD_TENSOR,
-name=self._scale_name,
-shape=[self._channel_num],
-dtype=self._dtype,
-persistable=False)
-out_scale.stop_gradient = True
-out, _, = core.ops.fake_channel_wise_quantize_dequantize_abs_max(
-input, quant_out, out_scale, *attrs)
+out, _, _, _ = core.ops.moving_average_abs_max_scale(
+input, accum, state, quant_out, self._scale, state, accum,
+*attrs)
return out
-check_variable_and_dtype(input, 'input', ['float32'],
-"FakeQuantChannelWiseAbsMax")
-attrs = {'bit_length': self._quant_bits, 'quant_axis': self._quant_axis}
+check_variable_and_dtype(input, 'input', ['float32', 'float64'],
+'MovingAverageAbsMaxScale')
+attrs = {'moving_rate': self._moving_rate, 'is_test': not self.training}
inputs = {"X": [input]}
quant_out = self._helper.create_variable(
-name="{}.quantized.dequantized".format(input.name),
+name="{}.tmp".format(input.name),
dtype=input.dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
-out_scale = self._scale
-if not out_scale:
-out_scale = self._helper.create_variable(
-name=self._scale_name,
-dtype=self._dtype,
-type=core.VarDesc.VarType.LOD_TENSOR,
-persistable=False,
-stop_gradient=True)
-outputs = {"Out": [quant_out], "OutScale": [out_scale]}
+outputs = {"Out": [quant_out], "OutScale": [self._scale]}
+if self.training:
+inputs['InState'] = [self._state]
+inputs['InAccum'] = [self._accum]
+outputs['OutState'] = [self._state]
+outputs['OutAccum'] = [self._accum]
self._helper.append_op(
-type="fake_channel_wise_quantize_dequantize_abs_max",
+type="moving_average_abs_max_scale",
inputs=inputs,
outputs=outputs,
attrs=attrs)
@@ -304,31 +388,7 @@ class FakeQuantChannelWiseAbsMax(layers.Layer):
return quant_out
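Aside (not part of the diff): a small numeric sketch of the MovingAverageAbsMaxScale docstring formula above, using the default moving_rate of 0.9 and made-up values for the running accum and state. It only illustrates the formula; the actual op also updates accum and state in place:

    moving_rate = 0.9
    accum, state = 2.0, 3.0          # assumed running values
    abs_max_of_x = 1.5               # max(abs(x)) of the current batch
    scale = (moving_rate * accum + abs_max_of_x) / (moving_rate * state + 1)
    print(scale)                     # (1.8 + 1.5) / (2.7 + 1) = about 0.8919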
-def _get_fake_quant_type(quant_type, **kwargs):
-call_args = {
-"name": kwargs.get("name", None),
-"quant_bits": kwargs.get("quant_bits", 8),
-"dtype": kwargs.get("dtype", "float32")
-}
-if quant_type == 'abs_max':
-call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
-elif quant_type == 'moving_average_abs_max':
-call_args["moving_rate"] = kwargs.get("moving_rate", 0.9)
-elif quant_type == 'channel_wise_abs_max':
-call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
-call_args["channel_num"] = kwargs.get("channel_num", None)
-call_args["quant_axis"] = kwargs.get("quant_axis", 0)
-assert call_args["channel_num"] is not None, (
-"You need to input channel_num"
-"when you use channel_wise_abs_max strategy.")
-fake_quant_map = {
-'abs_max': FakeQuantAbsMax,
-'moving_average_abs_max': FakeQuantMovingAverageAbsMax,
-'channel_wise_abs_max': FakeQuantChannelWiseAbsMax
-}
-return fake_quant_map[quant_type](**call_args)
+QuantStub = MovingAverageAbsMaxScale
class QuantizedConv2D(layers.Layer):
@@ -489,117 +549,10 @@ class QuantizedLinear(layers.Layer):
return out
-class QuantizedNoweightLayer(layers.Layer):
-def __init__(self,
-layer,
-weight_bits=8,
-activation_bits=8,
-moving_rate=0.9,
-*args,
-**kwargs):
-super(QuantizedNoweightLayer, self).__init__()
-self._layer = layer
-self._fake_quant_input = _get_fake_quant_type(
-'moving_average_abs_max',
-name=layer.full_name(),
-moving_rate=moving_rate,
-quant_bits=activation_bits,
-dtype=self._dtype,
-quant_on_weight=False)
-def forward(self, input):
-return self._layer.forward(self._fake_quant_input(input))
-class MovingAverageAbsMaxScale(layers.Layer):
-def __init__(self, name=None, moving_rate=0.9, dtype='float32'):
-r"""
-MovingAverageMaxScale layer is used to calculating the output quantization
-scale of Layer. Its computational formula is described as below:
-:math:`scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)`
-:math:`Out = X`
-"""
-super(MovingAverageAbsMaxScale, self).__init__()
-self._moving_rate = moving_rate
-scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
-scale_name = unique_name.generate(scale_prefix)
-scale_attr = ParamAttr(
-name=scale_name, initializer=Constant(1), trainable=False)
-self._scale = self.create_parameter(
-shape=[1], attr=scale_attr, dtype=dtype)
-self._scale.stop_gradient = True
-state_prefix = "{}.state".format(name) if name else 'outscale.state'
-state_attr = ParamAttr(
-name=unique_name.generate(state_prefix),
-initializer=Constant(1),
-trainable=False)
-self._state = self.create_parameter(
-shape=[1], attr=state_attr, dtype=dtype)
-self._state.stop_gradient = True
-accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
-accum_attr = ParamAttr(
-name=unique_name.generate(accum_prefix),
-initializer=Constant(1),
-trainable=False)
-self._accum = self.create_parameter(
-shape=[1], attr=accum_attr, dtype=dtype)
-self._accum.stop_gradient = True
-def forward(self, input):
-if in_dygraph_mode():
-attrs = ('moving_rate', self._moving_rate, 'is_test',
-not self.training)
-state = self._state if self.training else None
-accum = self._accum if self.training else None
-quant_out = _varbase_creator(
-type=input.type,
-name="{}.tmp".format(input.name),
-shape=input.shape,
-dtype=input.dtype,
-persistable=False)
-out, _, _, _ = core.ops.moving_average_abs_max_scale(
-input, accum, state, quant_out, self._scale, state, accum,
-*attrs)
-return out
-check_variable_and_dtype(input, 'input', ['float32', 'float64'],
-'MovingAverageAbsMaxScale')
-attrs = {'moving_rate': self._moving_rate, 'is_test': not self.training}
-inputs = {"X": [input]}
-quant_out = self._helper.create_variable(
-name="{}.tmp".format(input.name),
-dtype=input.dtype,
-type=core.VarDesc.VarType.LOD_TENSOR,
-persistable=False,
-stop_gradient=False)
-outputs = {"Out": [quant_out], "OutScale": [self._scale]}
-if self.training:
-inputs['InState'] = [self._state]
-inputs['InAccum'] = [self._accum]
-outputs['OutState'] = [self._state]
-outputs['OutAccum'] = [self._accum]
-self._helper.append_op(
-type="moving_average_abs_max_scale",
-inputs=inputs,
-outputs=outputs,
-attrs=attrs)
-return quant_out
class MAOutputScaleLayer(layers.Layer):
"""
-Calculate the scale (moving average abs max) for the output of the input layer.
Add MovingAverageMaxScale layer to the behind of the input layer.
+Calculate the scale (moving average abs max) for the output of the input layer.
"""
def __init__(self, layer=None, moving_rate=0.9, name=None, dtype='float32'):
@@ -623,6 +576,10 @@ class MAOutputScaleLayer(layers.Layer):
class FakeQuantMAOutputScaleLayer(layers.Layer):
+"""
+Add FakeQuantMovingAverageAbsMax layer to the behind of the input layer.
+"""
def __init__(self,
layer,
weight_bits=8,
@@ -649,3 +606,30 @@ class FakeQuantMAOutputScaleLayer(layers.Layer):
return out
else:
return self._fake_quant_output(out)
+def _get_fake_quant_type(quant_type, **kwargs):
+call_args = {
+"name": kwargs.get("name", None),
+"quant_bits": kwargs.get("quant_bits", 8),
+"dtype": kwargs.get("dtype", "float32")
+}
+if quant_type == 'abs_max':
+call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
+elif quant_type == 'moving_average_abs_max':
+call_args["moving_rate"] = kwargs.get("moving_rate", 0.9)
+elif quant_type == 'channel_wise_abs_max':
+call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
+call_args["channel_num"] = kwargs.get("channel_num", None)
+call_args["quant_axis"] = kwargs.get("quant_axis", 0)
+assert call_args["channel_num"] is not None, (
+"You need to input channel_num"
+"when you use channel_wise_abs_max strategy.")
+fake_quant_map = {
+'abs_max': FakeQuantAbsMax,
+'moving_average_abs_max': FakeQuantMovingAverageAbsMax,
+'channel_wise_abs_max': FakeQuantChannelWiseAbsMax
+}
+return fake_quant_map[quant_type](**call_args)