Unverified commit 548cdbc5 authored by Zhen Wang, committed by GitHub

Quantization-aware training for dygraph (#24634)

* Add the imperative quantization-aware training.
* This is the Python part of imperative QAT. test=develop
Parent 0b54d54f
......@@ -219,14 +219,14 @@ class FakeQuantOrWithDequantAbsMaxOpMaker
bit_length));
});
AddComment(R"DOC(
This is a Base Op which support FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker.
This is a Base Op which supports FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker.
FakeQuantAbsMaxOp operator is used in dynamic quantization.
$$scale = max(abs(X))$$
$$range = 2^{bit_length - 1} - 1$$
$$Out = round(X/scale * range)$$
FakeQuantDequantAbsMaxOp operator do the abs_max quant and then dequant.
FakeQuantDequantAbsMaxOp operator does the abs_max quantization and then dequantization.
$$scale = max(abs(X))$$
$$range = 2^{bit\_length - 1} - 1$$
......@@ -423,14 +423,14 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddComment(R"DOC(
This is a Base Op which support FakeQuantMovingAverageAbsMaxOp and FakeQuantDequantMovingAverageAbsMaxOp.
This is a Base Op which supports FakeQuantMovingAverageAbsMaxOp and FakeQuantDequantMovingAverageAbsMaxOp.
FakeQuantMovingAverageAbsMaxOp operator is used in static quantization.
$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out = round(X/scale * range)$$
FakeQuantDequantMovingAverageAbsMaxOp operator do the moving_average_abs_max quant and then dequant.
FakeQuantDequantMovingAverageAbsMaxOp operator does the moving_average_abs_max quant and then dequant.
$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
$$range = 2^{bit\_length - 1} - 1$$
......@@ -505,15 +505,12 @@ class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
auto x_grad_name = framework::GradVarName("X");
OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
"FakeQuantDequantGradOp");
OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name,
"FakeQuantDequantGradOp");
auto x_grad_name = framework::GradVarName("X");
PADDLE_ENFORCE_EQ(
ctx->HasOutput(x_grad_name), true,
platform::errors::PreconditionNotMet(
"FakeQuantDequantGradOp doesn't have the output named %s.",
x_grad_name));
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
}
......
......@@ -26,9 +26,12 @@ from . import quant2_int8_mkldnn_pass
from .quant2_int8_mkldnn_pass import *
from . import post_training_quantization
from .post_training_quantization import *
from . import imperative
from .imperative import *
__all__ = quantization_pass.__all__ + quantization_strategy.__all__
__all__ += mkldnn_post_training_strategy.__all__
__all__ += quant_int8_mkldnn_pass.__all__
__all__ += quant2_int8_mkldnn_pass.__all__
__all__ += post_training_quantization.__all__
__all__ += imperative.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from . import quant_nn
from .quant_nn import *
from . import qat
from .qat import *
__all__ = []
__all__ += quant_nn.__all__
__all__ += qat.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import sys
from paddle.fluid import dygraph
from paddle.fluid.dygraph.nn import Conv2D
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.log_helper import get_logger
from . import quant_nn
__all__ = ['ImperativeQuantAware']
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class ImperativeQuantAware(object):
"""
Add the fake quantization logic for the given quantizable layers, namely the quant-dequant
computational logic for both activation inputs and weight inputs.
"""
def __init__(self,
weight_bits=8,
activation_bits=8,
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max',
moving_rate=0.9,
quantizable_layer_type=['Conv2D', 'Linear']):
"""
The constructor for ImperativeQuantAware.
Args:
weight_bits(int): quantization bit number for weights,
whereas the bias is not quantized.
activation_bits(int): quantization bit number for activations.
weight_quantize_type(str): quantization type for weights;
only 'abs_max' is supported now. 'moving_average_abs_max' is
usually not used for weights, since weights are fixed once the
model is well trained.
activation_quantize_type(str): quantization type for activations;
'abs_max' and 'moving_average_abs_max' are supported now.
If using 'abs_max' mode, the quantization scale will be calculated
dynamically at each step during both training and testing. If using
'moving_average_abs_max', a static quantization scale will be calculated
during training and used during inference.
moving_rate(float): the parameter for 'moving_average_abs_max' quantization.
quantizable_layer_type(list[str]): the types of layers that will be quantized.
Default is ['Conv2D', 'Linear']. The quantizable_op_type in
QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
Examples:
.. code-block:: python
from paddle.fluid.contrib.slim.quantization \
import ImperativeQuantAware
from paddle.incubate.hapi.vision.models \
import resnet
model = resnet.resnet50(pretrained=True)
imperative_qat = ImperativeQuantAware(
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max')
# Add the fake quantization logic.
# The original model will be rewritten in place.
imperative_qat.quantize(model)
# Fine-tune the quantized model
# ...
# Save the quantized model for inference.
imperative_qat.save_quantized_model(
dirname="./resnet50_qat",
model=model,
input_shape=[(3, 224, 224)],
input_dtype=['float32'],
feed=[0],
fetch=[0])
"""
super(ImperativeQuantAware, self).__init__()
self._weight_bits = weight_bits
self._activation_bits = activation_bits
self._moving_rate = moving_rate
quant_type = {'abs_max', 'moving_average_abs_max'}
if activation_quantize_type not in quant_type:
raise ValueError(
"Unknown activation_quantize_type : '%s'. It can only be "
"'abs_max' or 'moving_average_abs_max' now." %
(str(activation_quantize_type)))
if weight_quantize_type not in quant_type:
raise ValueError(
"Unknown weight_quantize_type: '%s'. It can only be "
"'abs_max' or 'moving_average_abs_max' now." %
(str(weight_quantize_type)))
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
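# Resolve layer-type names to dygraph layer classes; entries that are not
# in the map must already be layer classes (plain strings are rejected below).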
self._quant_layers_map = {'Conv2D': Conv2D, 'Linear': Linear}
self._quantizable_layer_type = tuple(
self._quant_layers_map[layer]
if layer in self._quant_layers_map else layer
for layer in quantizable_layer_type)
for layer in self._quantizable_layer_type:
assert not isinstance(
layer, str), "{} is not supported for quantization.".format(layer)
def quantize(self, model):
"""
According to the weights' and activations' quantization types, fake quantization ops such as
fake_quantize_dequantize_moving_average_abs_max and fake_quantize_dequantize_abs_max
will be inserted into the model.
Args:
model(fluid.dygraph.Layer): the model to be quantized.
Returns:
None
"""
for name, layer in model.named_sublayers():
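# Skip layers that are not of a quantizable type; for the rest, walk the
# dotted sub-layer name to the direct parent layer and replace the child
# with its quantized counterpart in place.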
if not isinstance(layer, self._quantizable_layer_type):
continue
scopes = name.split('.')
target = scopes[-1]
obj = model
parent = model
for i in range(len(scopes) - 1):
obj = getattr(parent, scopes[i])
parent = obj
quant_layer = self._get_quantized_counterpart(layer)
setattr(obj, target, quant_layer)
def save_quantized_model(self,
dirname,
model,
input_shape,
input_dtype,
feed,
fetch,
append_batch_size=True):
"""
Save the quantized model for inference.
Args:
dirname (str): the directory to save the quantized model.
model(fluid.dygraph.Layer): the quantized model to be saved.
input_shape(list[tuple(int)]): The shape value for each input,
e.g. [(3, 224, 224)].
input_dtype(list[str]): The dtype value for each input,
e.g. ['float32'].
feed(list[int]): the indices of the input variables of the
imperative functions which will be saved as input variables in
the inference model.
fetch(list[int]): the indices of the returned variables of the
imperative functions which will be saved as output variables in
the inference model.
append_batch_size(bool, optional):
If True, an extra batch axis is prepended to each input shape, and
input_shape should not contain the batch-size dimension.
Otherwise, input_shape is used as-is. Default: True.
Returns:
None
"""
assert isinstance(
input_shape, list), "The parameter `input_shape` should be a list."
assert isinstance(
input_dtype, list), "The parameter `input_dtype` should be a list."
assert isinstance(feed, list), "The parameter `feed` should be a list."
assert isinstance(fetch,
list), "The parameter `fetch` should be a list."
assert len(input_shape) == len(
input_dtype
), "The length of input_shape should be equal to input_dtype's."
assert len(input_dtype) == len(
feed), "The length of input_dtype should be equal to feed's."
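# Trace the dygraph model once with ProgramTranslator, using randomly
# generated inputs of the declared shapes and dtypes, and then save the
# traced program as an inference model.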
def _convert(model, *args):
return model(*args)
prog_trans = dygraph.ProgramTranslator()
with dygraph.guard():
model.eval()
input_vars = []
for shape, dtype in zip(input_shape, input_dtype):
raw_data = np.random.random(shape)
input_data = raw_data[np.newaxis, :].astype(
dtype) if append_batch_size else raw_data.astype(dtype)
input_var = dygraph.to_variable(input_data)
input_vars.append(input_var)
prog_trans.get_output(_convert, model, *input_vars)
prog_trans.save_inference_model(dirname, feed, fetch)
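# Map a quantizable layer instance (e.g. Conv2D) to its quantized
# counterpart class (e.g. QuantizedConv2D) defined in quant_nn.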
def _get_quantized_counterpart(self, layer):
quant_layers = tuple(self._quant_layers_map.values())
quantized_counterpart = tuple('Quantized' + k
for k in self._quant_layers_map.keys())
predicate = lambda value: isinstance(layer, value)
index_generator = (i for i, v in enumerate(quant_layers)
if predicate(v))
try:
index = next(index_generator)
except StopIteration:
_logger.fatal("The layer {} is not supported for quantization.".format(
layer.full_name()))
sys.exit(-1)
quantized_layer = quant_nn.__dict__[quantized_counterpart[index]](
layer, self._weight_bits, self._activation_bits, self._moving_rate,
self._weight_quantize_type, self._activation_quantize_type)
return quantized_layer
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.dygraph import layers
from paddle.fluid import core
from paddle.fluid import dygraph_utils
from paddle.fluid import unique_name
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import _varbase_creator
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.initializer import Constant
from paddle.fluid.data_feeder import check_variable_and_dtype
__all__ = [
'FakeQuantMovingAverage', 'FakeQuantAbsMax', 'QuantizedConv2D',
'QuantizedLinear'
]
class FakeQuantMovingAverage(layers.Layer):
"""
FakeQuantMovingAverage layer performs moving_average_abs_max quantization and then dequantization.
Its computational formulas are described below:
:math:`scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)`
:math:`range = 2^{bit\_length - 1} - 1`
:math:`Out = round(X / scale * range) * scale / range`
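For example (an illustrative calculation only): with moving_rate = 0.9, accum = 0.5,
state = 1 and max(abs(x)) = 0.8, the scale is (0.9*0.5 + 0.8) / (0.9*1 + 1) ≈ 0.658.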
"""
def __init__(self,
name=None,
moving_rate=0.9,
quant_bits=8,
dtype='float32'):
super(FakeQuantMovingAverage, self).__init__()
self._moving_rate = moving_rate
self._quant_bits = quant_bits
scale_prefix = "{}.scale".format(
name) if name else 'quant_dequant.scale'
scale_attr = ParamAttr(
name=unique_name.generate(scale_prefix),
initializer=Constant(0.001),
trainable=False)
self._scale = self.create_parameter(
shape=[1], attr=scale_attr, dtype=dtype)
self._scale.stop_gradient = True
state_prefix = "{}.state".format(
name) if name else 'quant_dequant.state'
state_attr = ParamAttr(
name=unique_name.generate(state_prefix),
initializer=Constant(1),
trainable=False)
self._state = self.create_parameter(
shape=[1], attr=state_attr, dtype=dtype)
self._state.stop_gradient = True
accum_prefix = "{}.accum".format(
name) if name else 'quant_dequant.accum'
accum_attr = ParamAttr(
name=unique_name.generate(accum_prefix),
initializer=Constant(1),
trainable=False)
self._accum = self.create_parameter(
shape=[1], attr=accum_attr, dtype=dtype)
self._accum.stop_gradient = True
def forward(self, input):
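# In dygraph mode, call the fused C++ op directly; the state and accum
# moving-average buffers are only fed and updated during training.
# Otherwise, append the equivalent op to the program via the layer helper.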
if in_dygraph_mode():
attrs = ('moving_rate', self._moving_rate, 'bit_length',
self._quant_bits, 'is_test', not self.training)
quant_out = _varbase_creator(
type=input.type,
name="{}.quantized.dequantized".format(input.name),
shape=input.shape,
dtype=input.dtype,
persistable=False)
state = self._state if self.training else None
accum = self._accum if self.training else None
out, _, _, _ = core.ops.fake_quantize_dequantize_moving_average_abs_max(
input, self._scale, accum, state, quant_out, self._scale, state,
accum, *attrs)
return out
check_variable_and_dtype(input, 'input', ['float32'],
"FakeQuantMovingAverage")
attrs = {
'moving_rate': self._moving_rate,
'bit_length': self._quant_bits,
'is_test': not self.training
}
inputs = {"X": [input], "InScale": [self._scale]}
quant_out = self._helper.create_variable(
name="{}.quantized.dequantized".format(input.name),
dtype=input.dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
outputs = {"Out": [quant_out], "OutScale": [self._scale]}
if self.training:
inputs['InState'] = [self._state]
inputs['InAccum'] = [self._accum]
outputs['OutState'] = [self._state]
outputs['OutAccum'] = [self._accum]
self._helper.append_op(
type="fake_quantize_dequantize_moving_average_abs_max",
inputs=inputs,
outputs=outputs,
attrs=attrs)
return quant_out
class FakeQuantAbsMax(layers.Layer):
"""
FakeQuantAbsMax layer performs abs_max quantization and then dequantization.
Its computational formulas are described below:
:math:`scale = max(abs(X))`
:math:`range = 2^{bit\_length - 1} - 1`
:math:`Out = round(X / scale * range) * scale / range`
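For example (an illustrative calculation only): with 8-bit quantization and
X = [0.1, -0.4, 0.3], scale = 0.4, range = 127, and
Out = round(X / 0.4 * 127) * 0.4 / 127 ≈ [0.1008, -0.4, 0.2992].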
"""
def __init__(self,
name=None,
quant_bits=8,
dtype='float32',
quant_on_weight=False):
super(FakeQuantAbsMax, self).__init__()
self._quant_bits = quant_bits
self._dtype = dtype
self._name = name
scale_prefix = "{}.scale".format(
name) if name else 'quant_dequant.scale'
self._scale_name = unique_name.generate(scale_prefix)
if quant_on_weight:
scale_attr = ParamAttr(
name=self._scale_name,
initializer=Constant(0.0),
trainable=False)
self._scale = self.create_parameter(
shape=[1], attr=scale_attr, dtype=self._dtype)
self._scale.stop_gradient = True
else:
self._scale = None
def forward(self, input):
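# In dygraph mode, call the fused C++ op directly. When quantizing
# activations (quant_on_weight=False) there is no persistent scale
# parameter, so a temporary variable receives the computed abs_max scale.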
if in_dygraph_mode():
attrs = ('bit_length', self._quant_bits)
quant_out = _varbase_creator(
type=input.type,
name="{}.quantized.dequantized".format(input.name),
shape=input.shape,
dtype=input.dtype,
persistable=False)
out_scale = self._scale
if not out_scale:
out_scale = _varbase_creator(
type=core.VarDesc.VarType.LOD_TENSOR,
name=self._scale_name,
shape=[1],
dtype=self._dtype,
persistable=False)
out_scale.stop_gradient = True
out, _, = core.ops.fake_quantize_dequantize_abs_max(
input, quant_out, out_scale, *attrs)
return out
check_variable_and_dtype(input, 'input', ['float32'], "FakeQuantAbsMax")
attrs = {'bit_length': self._quant_bits}
inputs = {"X": [input]}
quant_out = self._helper.create_variable(
name="{}.quantized.dequantized".format(input.name),
dtype=input.dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
out_scale = self._scale
if not out_scale:
out_scale = self._helper.create_variable(
name=self._scale_name,
dtype=self._dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
outputs = {"Out": [quant_out], "OutScale": [out_scale]}
self._helper.append_op(
type="fake_quantize_dequantize_abs_max",
inputs=inputs,
outputs=outputs,
attrs=attrs)
return quant_out
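# Factory helper: build the fake quantization layer that matches quant_type
# ('abs_max' -> FakeQuantAbsMax, 'moving_average_abs_max' -> FakeQuantMovingAverage).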
def _get_fake_quant_type(quant_type, name, moving_rate, quant_bits, dtype,
quant_on_weight):
fake_quant_map = {
'abs_max':
lambda: FakeQuantAbsMax(name, quant_bits, dtype, quant_on_weight),
'moving_average_abs_max':
lambda: FakeQuantMovingAverage(name, moving_rate, quant_bits, dtype)
}
return fake_quant_map[quant_type]()
class QuantizedConv2D(layers.Layer):
"""
The computational logic of QuantizedConv2D is the same as that of Conv2D.
The only difference is that its inputs are all fake quantized.
"""
def __init__(self,
layer,
weight_bits=8,
activation_bits=8,
moving_rate=0.9,
weight_quantize_type='abs_max',
activation_quantize_type='abs_max'):
super(QuantizedConv2D, self).__init__()
# For Conv2D
self._groups = getattr(layer, '_groups')
self._stride = getattr(layer, '_stride')
self._padding = getattr(layer, '_padding')
self._dilation = getattr(layer, '_dilation')
self._act = getattr(layer, '_act')
self._use_cudnn = getattr(layer, '_use_cudnn')
self._dtype = getattr(layer, '_dtype')
self._l_type = getattr(layer, '_l_type')
self.weight = getattr(layer, 'weight')
self.bias = getattr(layer, 'bias')
# For FakeQuant
self._fake_quant_weight = _get_fake_quant_type(
weight_quantize_type, self.weight.name, moving_rate, weight_bits,
self._dtype, True)
self._fake_quant_input = _get_fake_quant_type(
activation_quantize_type,
layer.full_name(), moving_rate, activation_bits, self._dtype, False)
def forward(self, input):
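# Fake-quantize both the activation input and the weight, then run the
# same conv2d + bias + activation computation as the wrapped Conv2D layer.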
quant_input = self._fake_quant_input(input)
quant_weight = self._fake_quant_weight(self.weight)
if in_dygraph_mode() and self._l_type == 'conv2d':
attrs = ('strides', self._stride, 'paddings', self._padding,
'dilations', self._dilation, 'groups', self._groups
if self._groups else 1, 'use_cudnn', self._use_cudnn)
pre_bias = core.ops.conv2d(quant_input, quant_weight, *attrs)
pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, self.bias,
1)
return dygraph_utils._append_activation_in_dygraph(pre_act,
self._act)
check_variable_and_dtype(quant_input, 'input',
['float16', 'float32', 'float64'],
'QuantizedConv2D')
attrs = {
'strides': self._stride,
'paddings': self._padding,
'dilations': self._dilation,
'groups': self._groups if self._groups else 1,
'use_cudnn': self._use_cudnn,
'use_mkldnn': False,
}
pre_bias = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type=self._l_type,
inputs={
'Input': quant_input,
'Filter': quant_weight,
},
outputs={"Output": pre_bias},
attrs=attrs)
if self.bias is not None:
pre_act = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type='elementwise_add',
inputs={'X': [pre_bias],
'Y': [self.bias]},
outputs={'Out': [pre_act]},
attrs={'axis': 1})
else:
pre_act = pre_bias
return self._helper.append_activation(pre_act, act=self._act)
class QuantizedLinear(layers.Layer):
"""
The computational logic of QuantizedLinear is the same as that of Linear.
The only difference is that its inputs are all fake quantized.
"""
def __init__(self,
layer,
weight_bits=8,
activation_bits=8,
moving_rate=0.9,
weight_quantize_type='abs_max',
activation_quantize_type='abs_max'):
super(QuantizedLinear, self).__init__()
# For Linear
self._act = getattr(layer, '_act')
self._dtype = getattr(layer, '_dtype')
self.weight = getattr(layer, 'weight')
self.bias = getattr(layer, 'bias')
# For FakeQuant
self._fake_quant_weight = _get_fake_quant_type(
weight_quantize_type, self.weight.name, moving_rate, weight_bits,
self._dtype, True)
self._fake_quant_input = _get_fake_quant_type(
activation_quantize_type,
layer.full_name(), moving_rate, activation_bits, self._dtype, False)
def forward(self, input):
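# Fake-quantize the input and the weight, then run the same
# matmul + elementwise_add + activation computation as the wrapped Linear layer.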
quant_input = self._fake_quant_input(input)
quant_weight = self._fake_quant_weight(self.weight)
if in_dygraph_mode():
pre_bias = _varbase_creator(dtype=input.dtype)
core.ops.matmul(quant_input, quant_weight, pre_bias, 'transpose_X',
False, 'transpose_Y', False, "alpha", 1)
pre_act = dygraph_utils._append_bias_in_dygraph(
pre_bias, self.bias, axis=len(input.shape) - 1)
return dygraph_utils._append_activation_in_dygraph(pre_act,
self._act)
check_variable_and_dtype(input, 'input',
['float16', 'float32', 'float64'],
"QuantizedLinear")
attrs = {
"transpose_X": False,
"transpose_Y": False,
"alpha": 1,
}
inputs = {"X": [quant_input], "Y": [quant_weight]}
mul_out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="matmul",
inputs=inputs,
outputs={"Out": [mul_out]},
attrs=attrs)
if self.bias is not None:
pre_activation = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type='elementwise_add',
inputs={'X': [mul_out],
'Y': [self.bias]},
outputs={'Out': [pre_activation]},
attrs={'axis': len(input.shape) - 1})
else:
pre_activation = mul_out
return self._helper.append_activation(pre_activation, act=self._act)
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
from __future__ import print_function
import os
import numpy as np
import random
import unittest
import logging
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.nn import Conv2D
from paddle.fluid.dygraph.nn import Pool2D
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.log_helper import get_logger
os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
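# A static-graph LeNet definition whose parameter attrs use the same names
# as those in ImperativeLenet below.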
def StaticLenet(data, num_classes=10, classifier_activation='softmax'):
conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
fc_w2_attr = fluid.ParamAttr(name="fc_w_2")
fc_w3_attr = fluid.ParamAttr(name="fc_w_3")
conv2d_b1_attr = fluid.ParamAttr(name="conv2d_b_1")
conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2")
fc_b1_attr = fluid.ParamAttr(name="fc_b_1")
fc_b2_attr = fluid.ParamAttr(name="fc_b_2")
fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
conv1 = fluid.layers.conv2d(
data,
num_filters=6,
filter_size=3,
stride=1,
padding=1,
param_attr=conv2d_w1_attr,
bias_attr=conv2d_b1_attr)
pool1 = fluid.layers.pool2d(
conv1, pool_size=2, pool_type='max', pool_stride=2)
conv2 = fluid.layers.conv2d(
pool1,
num_filters=16,
filter_size=5,
stride=1,
padding=0,
param_attr=conv2d_w2_attr,
bias_attr=conv2d_b2_attr)
pool2 = fluid.layers.pool2d(
conv2, pool_size=2, pool_type='max', pool_stride=2)
fc1 = fluid.layers.fc(input=pool2,
size=120,
param_attr=fc_w1_attr,
bias_attr=fc_b1_attr)
fc2 = fluid.layers.fc(input=fc1,
size=84,
param_attr=fc_w2_attr,
bias_attr=fc_b2_attr)
fc3 = fluid.layers.fc(input=fc2,
size=num_classes,
act=classifier_activation,
param_attr=fc_w3_attr,
bias_attr=fc_b3_attr)
return fc3
class ImperativeLenet(fluid.dygraph.Layer):
def __init__(self, num_classes=10, classifier_activation='softmax'):
super(ImperativeLenet, self).__init__()
conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
fc_w2_attr = fluid.ParamAttr(name="fc_w_2")
fc_w3_attr = fluid.ParamAttr(name="fc_w_3")
conv2d_b1_attr = fluid.ParamAttr(name="conv2d_b_1")
conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2")
fc_b1_attr = fluid.ParamAttr(name="fc_b_1")
fc_b2_attr = fluid.ParamAttr(name="fc_b_2")
fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
self.features = Sequential(
Conv2D(
num_channels=1,
num_filters=6,
filter_size=3,
stride=1,
padding=1,
param_attr=conv2d_w1_attr,
bias_attr=conv2d_b1_attr),
Pool2D(
pool_size=2, pool_type='max', pool_stride=2),
Conv2D(
num_channels=6,
num_filters=16,
filter_size=5,
stride=1,
padding=0,
param_attr=conv2d_w2_attr,
bias_attr=conv2d_b2_attr),
Pool2D(
pool_size=2, pool_type='max', pool_stride=2))
self.fc = Sequential(
Linear(
input_dim=400,
output_dim=120,
param_attr=fc_w1_attr,
bias_attr=fc_b1_attr),
Linear(
input_dim=120,
output_dim=84,
param_attr=fc_w2_attr,
bias_attr=fc_b2_attr),
Linear(
input_dim=84,
output_dim=num_classes,
act=classifier_activation,
param_attr=fc_w3_attr,
bias_attr=fc_b3_attr))
def forward(self, inputs):
x = self.features(inputs)
x = fluid.layers.flatten(x, 1)
x = self.fc(x)
return x
class TestImperativeQat(unittest.TestCase):
"""
QAT = quantization-aware training
"""
def test_qat_save(self):
imperative_qat = ImperativeQuantAware(
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max')
with fluid.dygraph.guard():
lenet = ImperativeLenet()
imperative_qat.quantize(lenet)
adam = AdamOptimizer(
learning_rate=0.001, parameter_list=lenet.parameters())
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=32, drop_last=True)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=32)
epoch_num = 1
for epoch in range(epoch_num):
lenet.train()
for batch_id, data in enumerate(train_reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
acc = fluid.layers.accuracy(out, label)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
lenet.clear_gradients()
if batch_id % 100 == 0:
_logger.info(
"Train | At epoch {} step {}: loss = {:}, acc= {:}".
format(epoch, batch_id,
avg_loss.numpy(), acc.numpy()))
lenet.eval()
for batch_id, data in enumerate(test_reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
acc_top1 = fluid.layers.accuracy(
input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(
input=out, label=label, k=5)
if batch_id % 100 == 0:
_logger.info(
"Test | At epoch {} step {}: acc1 = {:}, acc5 = {:}".
format(epoch, batch_id,
acc_top1.numpy(), acc_top5.numpy()))
# save weights
model_dict = lenet.state_dict()
fluid.save_dygraph(model_dict, "save_temp")
# test the correctness of `save_quantized_model`
data = next(test_reader())
test_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
test_img = fluid.dygraph.to_variable(test_data)
lenet.eval()
before_save = lenet(test_img)
# save inference quantized model
path = "./mnist_infer_model"
imperative_qat.save_quantized_model(
dirname=path,
model=lenet,
input_shape=[(1, 28, 28)],
input_dtype=['float32'],
feed=[0],
fetch=[0])
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
[inference_program, feed_target_names, fetch_targets] = (
fluid.io.load_inference_model(
dirname=path, executor=exe))
after_save, = exe.run(inference_program,
feed={feed_target_names[0]: test_data},
fetch_list=fetch_targets)
self.assertTrue(
np.allclose(after_save, before_save.numpy()),
msg='Failed to save the inference quantized model.')
if __name__ == '__main__':
unittest.main()
......@@ -168,6 +168,7 @@ packages=['paddle',
'paddle.fluid.contrib.slim.graph',
'paddle.fluid.contrib.slim.prune',
'paddle.fluid.contrib.slim.quantization',
'paddle.fluid.contrib.slim.quantization.imperative',
'paddle.fluid.contrib.slim.distillation',
'paddle.fluid.contrib.slim.nas',
'paddle.fluid.contrib.slim.searcher',
......