Unverified commit 63840227, authored by guofei and committed by GitHub

Integrate ImperativeOutScale into ImperativeQuantAware. (#27956)

* Optimize the unittest test_imperative_out_scale

test=develop
Parent b9e76a01
@@ -23,9 +23,10 @@ from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn import Linear, Conv2D
from paddle.fluid.dygraph.nn import BatchNorm, Pool2D, Conv2DTranspose
from paddle.fluid.io import load_inference_model, save_inference_model
- from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6, Tanh, Softmax, PReLU
+ from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6, Tanh, Softmax, PReLU, Swish
from paddle.fluid.log_helper import get_logger
from . import quant_nn
+ from .. import quantization_pass
__all__ = ['ImperativeQuantAware', 'ImperativeCalcOutScale']
@@ -45,6 +46,7 @@ _op_real_in_out_name = {
"tanh": [["X"], ["Out"]],
"batch_norm": [["X"], ["Y"]],
"sigmoid": [["X"], ["Out"]],
+ "swish": [["X"], ["Out"]],
}
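For orientation, a small sketch of how an entry in this table is read; the dict below simply mirrors two of the entries above, and the operator handle mentioned in the comment stands in for a static-graph op in the real pass code:

    # Mirrors the structure of _op_real_in_out_name: op type -> [input slots, output slots].
    op_real_in_out_name = {
        "sigmoid": [["X"], ["Out"]],
        "swish": [["X"], ["Out"]],
    }
    in_slots, out_slots = op_real_in_out_name["swish"]
    # out_slots == ["Out"] names the output slot; the concrete variable names are
    # then fetched from the operator, e.g. op.output("Out"), as done later in this file.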
@@ -109,7 +111,12 @@ class ImperativeQuantAware(object):
activation and returns dequantized activation. If None, will use
quantization op defined by 'activation_quantize_type'. Default is None.
- Examples:
+ Note:
+ If a user sets the attribute 'skip_quant' of a Layer that supports dynamic
+ quantization to True, the layer will not be quantized during training. If this
+ attribute is not set or is False, the layer will be quantized during training.
+ Examples 1:
.. code-block:: python
import paddle
@@ -126,18 +133,62 @@ class ImperativeQuantAware(object):
# Add the fake quant logical.
# The original model will be rewrite.
+ # The out scale of outputs in supported layers would be calculated.
imperative_qat.quantize(model)
# Fine-tune the quantized model
# ...
# Save quant model for the inference.
- paddle.jit.save(
+ imperative_qat.save_quantized_model(
layer=model,
model_path="./resnet50_qat",
input_spec=[
paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype='float32')])
Examples 2:
.. code-block:: python
import paddle
from paddle.fluid.contrib.slim.quantization \
import ImperativeQuantAware
class ImperativeModel(paddle.nn.Layer):
def __init__(self):
super(ImperativeModel, self).__init__()
# self.linear_0 would skip the quantization.
self.linear_0 = paddle.nn.Linear(784, 400)
self.linear_0.skip_quant = True
# self.linear_1 would not skip the quantization.
self.linear_1 = paddle.nn.Linear(400, 10)
self.linear_1.skip_quant = False
def forward(self, inputs):
x = self.linear_0(inputs)
x = self.linear_1(x)
return x
model = ImperativeModel()
imperative_qat = ImperativeQuantAware(
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max')
# Add the fake quant logic.
# The original model will be rewritten.
#
# Only one layer, self.linear_1, will have the fake quant
# logic added, because self.linear_0 sets skip_quant = True.
imperative_qat.quantize(model)
# Fine-tune the quantized model
# ...
# Save quant model for the inference.
imperative_qat.save_quantized_model(
layer=model,
model_path="./imperative_model_qat")
""" """
super(ImperativeQuantAware, self).__init__() super(ImperativeQuantAware, self).__init__()
self._weight_bits = weight_bits self._weight_bits = weight_bits
@@ -150,6 +201,7 @@ class ImperativeQuantAware(object):
self._act_pre_layer = act_preprocess_layer
self._weight_quant_layer = weight_quantize_layer
self._act_quant_layer = act_quantize_layer
+ self._out_scale = ImperativeCalcOutScale()
t_check = lambda method: method is None or issubclass(method, dygraph.layers.Layer)
assert t_check(
@@ -189,7 +241,7 @@ class ImperativeQuantAware(object):
"""
According to weights' and activations' quantization types, the model will be added some fake
quant ops, such as fake_quantize_dequantize_moving_average_abs_max, fake_quantize_dequantize_abs_max
- and so on.
+ and so on. At the same time, the out_scale value of outputs would be calculated.
Args:
model(fluid.dygraph.Layer): the model to be quantized.
@@ -199,6 +251,9 @@ class ImperativeQuantAware(object):
for name, layer in model.named_sublayers():
if not isinstance(layer, self._quantizable_layer_type):
continue
+ if hasattr(layer, "skip_quant") and layer.skip_quant == True:
+     continue
scopes = name.split('.')
target = scopes[-1]
obj = model
@@ -210,6 +265,8 @@ class ImperativeQuantAware(object):
quant_layer = self._get_quantized_counterpart(layer)
setattr(obj, target, quant_layer)
+ self._out_scale.calc_out_scale(model)
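As a rough illustration of what quantize() now does (TinyNet below is a made-up toy model, not part of this change): supported sublayers are replaced in place by their 'Quantized*' counterparts, and the out-scale calculation is wired onto the model via calc_out_scale():

    import paddle
    from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

    class TinyNet(paddle.nn.Layer):
        def __init__(self):
            super(TinyNet, self).__init__()
            self.linear_0 = paddle.nn.Linear(16, 4)

        def forward(self, x):
            return self.linear_0(x)

    model = TinyNet()
    imperative_qat = ImperativeQuantAware(
        weight_quantize_type='abs_max',
        activation_quantize_type='moving_average_abs_max')
    imperative_qat.quantize(model)
    # The Linear sublayer has been swapped for its quantized counterpart,
    # e.g. 'QuantizedLinear', while forward hooks track output scales.
    print(type(model.linear_0).__name__)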
def _get_quantized_counterpart(self, layer):
quant_layers = tuple(self._quant_layers_map.values())
quantized_counterpart = tuple('Quantized' + k
@@ -233,47 +290,24 @@ class ImperativeQuantAware(object):
self._weight_quant_layer, self._act_quant_layer)
return quantized_layer
+ def save_quantized_model(self, layer, path, input_spec=None, **config):
+     self._out_scale.save_quantized_model(layer, path, input_spec, **config)
class ImperativeCalcOutScale(object):
- def __init__(self,
-              moving_rate=0.9,
-              target_layer_types=[
-                  'BatchNorm', 'Conv2D', 'Conv2DTranspose', 'LeakyReLU',
-                  'Linear', 'PReLU', 'Pool2D', 'ReLU', 'ReLU6', 'Sigmoid',
-                  'Softmax', 'Tanh'
-              ]):
+ def __init__(self, moving_rate=0.9):
""" """
Add the logic of calculating and setting output quantization scales of some layers. Add the logic of calculating and setting output quantization scales of some layers.
These output quantization scales may be used by tensorRT or some other inference engines. These output quantization scales may be used by tensorRT or some other inference engines.
Args: Args:
moving_rate(float): The decay coefficient of moving average. The default value is 0.9. moving_rate(float): The decay coefficient of moving average. The default value is 0.9.
quantizable_op_type(list[str]): List the type of layers that will be calculated out_scale.
Default is ['Conv2D', 'ReLU', 'PReLU', 'LeakyReLU', 'Linear', 'Sigmoid', 'BatchNorm', 'ReLU6', 'Tanh', 'Softmax', 'Conv2DTranspose']
""" """
super(ImperativeCalcOutScale, self).__init__()
self._moving_rate = moving_rate
- self._out_scale_layers_map = {
-     'BatchNorm': BatchNorm,
-     'Conv2D': Conv2D,
-     'Conv2DTranspose': Conv2DTranspose,
-     'LeakyReLU': LeakyReLU,
-     'Linear': Linear,
-     'PReLU': PReLU,
-     'Pool2D': Pool2D,
-     'ReLU': ReLU,
-     'ReLU6': ReLU6,
-     'Sigmoid': Sigmoid,
-     'Softmax': Softmax,
-     'Tanh': Tanh
- }
- self._out_scale_layer_type = tuple(
-     self._out_scale_layers_map[layer]
-     if layer in self._out_scale_layers_map else layer
-     for layer in target_layer_types)
- for layer in self._out_scale_layer_type:
-     assert not isinstance(
-         layer, str), "{} is unspported to be out_scaled.".format(layer)
+ self._out_scale_layer_type_list = (
+     BatchNorm, Conv2D, Conv2DTranspose, LeakyReLU, Linear, PReLU,
+     Pool2D, ReLU, ReLU6, Sigmoid, Softmax, Tanh, Swish)
self._register_hook_handle_list = []
self._out_scale_dict = {}
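For intuition about the moving_rate stored above: it is the decay coefficient of a running estimate of each output's absolute maximum, which later becomes the layer's out_threshold. A rough, framework-independent sketch of that idea follows (the in-framework fake-quantize op keeps its own state and may use a slightly different update rule):

    import numpy as np

    def update_out_scale(prev_scale, output, moving_rate=0.9):
        # Running estimate of the output tensor's absolute maximum.
        batch_abs_max = float(np.abs(output).max())
        if prev_scale is None:
            return batch_abs_max  # first observation: take the batch max directly
        return moving_rate * prev_scale + (1.0 - moving_rate) * batch_abs_max

    scale = None
    for _ in range(5):
        scale = update_out_scale(scale, np.random.randn(8, 16))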
@@ -290,26 +324,12 @@ class ImperativeCalcOutScale(object):
assert isinstance(
model, dygraph.Layer), "model must be the instance of dygraph.Layer"
for _, layer in model.named_sublayers():
- if not isinstance(layer, self._out_scale_layer_type):
+ if not isinstance(layer, self._out_scale_layer_type_list):
continue
forward_post_hook_handle = layer.register_forward_post_hook(
self._forward_post_hook)
self._register_hook_handle_list.append(forward_post_hook_handle)
- # Get the output var name of the op
- def _get_op_output_names(self, op):
-     assert isinstance(
-         op, framework.Operator), "The input op should be Operator."
-     var_names = []
-     name_list = _op_real_in_out_name[op.type][1]
-     for name in name_list:
-         var_name = op.output(name)
-         if isinstance(var_name, list):
-             var_names.extend(var_name)
-         else:
-             var_names.append(var_name)
-     return var_names
def save_quantized_model(self, layer, path, input_spec=None, **config):
"""
Save the quantized model for the inference.
@@ -335,6 +355,7 @@ class ImperativeCalcOutScale(object):
assert isinstance(
layer, dygraph.Layer), "model must be the instance of dygraph.Layer"
+ is_dynamic_mode = False
with dygraph.guard():
layer.eval()
for handle in self._register_hook_handle_list:
@@ -345,6 +366,10 @@ class ImperativeCalcOutScale(object):
paddle.jit.save(layer=layer, path=path, input_spec=input_spec, **config)
+ if paddle.in_dynamic_mode():
+     is_dynamic_mode = True
+     paddle.enable_static()
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
@@ -369,7 +394,8 @@ class ImperativeCalcOutScale(object):
for block in inference_program.blocks:
for op in block.ops:
if op.type in _op_real_in_out_name:
- output_var_names = self._get_op_output_names(op)
+ output_var_names = quantization_pass._get_op_output_var_names(
+     op)
for output_var_name in output_var_names:
output_var_tensor = block.var(output_var_name)
if output_var_tensor.dtype not in [
@@ -386,6 +412,8 @@ class ImperativeCalcOutScale(object):
# to dygraph Layer by the name of output. And use dict to save
# the corresponding relationship between the dygraph Layer and the
# static graph op that needs to set the outscale attribute.
+ if '.' not in output_var_name:
+     continue
dynamic_layer_name, var_name_suffix = output_var_name.split(
".")
if dynamic_layer_name in layer_var_dict:
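The guard added above skips variables whose names carry no layer prefix; for the ones that do, the mapping described in the comment works roughly like this (the variable name below is made up for illustration):

    output_var_name = "linear_0.tmp_1"  # hypothetical static-graph output variable name
    dynamic_layer_name, var_name_suffix = output_var_name.split(".")
    # dynamic_layer_name == "linear_0": the dygraph sublayer whose out_threshold is set
    # var_name_suffix == "tmp_1"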
@@ -420,9 +448,12 @@ class ImperativeCalcOutScale(object):
model_filename=model_filename,
params_filename=params_filename)
+ if is_dynamic_mode:
+     paddle.disable_static()
def _forward_post_hook(self, layer, input, output):
assert isinstance(
- output, core.VarBase
+ output, (core.VarBase, framework.Variable)
), "Multiple outputs are not currently supported in ImperativeOutScale."
if output.dtype not in [
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP64
......
@@ -25,12 +25,13 @@ import paddle.fluid.layers as layers
from paddle.fluid import core
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.framework import IrGraph
- from paddle.fluid.contrib.slim.quantization import ImperativeCalcOutScale
+ from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
- from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass, OutScaleForInferencePass
+ from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass, OutScaleForInferencePass, QuantizationTransformPass
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn.layer import ReLU, LeakyReLU, Sigmoid, Softmax, ReLU6
- from paddle.fluid.dygraph.nn import BatchNorm, Conv2D, Linear, Pool2D
+ from paddle.nn import Linear, Conv2D, Softmax, BatchNorm
+ from paddle.fluid.dygraph.nn import Pool2D
from paddle.fluid.log_helper import get_logger
paddle.enable_static()
@@ -91,10 +92,10 @@ def StaticLenet(data, num_classes=10, classifier_activation='softmax'):
sigmoid1 = layers.sigmoid(fc2)
fc3 = fluid.layers.fc(input=sigmoid1,
size=num_classes,
- act=classifier_activation,
param_attr=fc_w3_attr,
bias_attr=fc_b3_attr)
- return fc3
+ softmax1 = layers.softmax(fc3, use_cudnn=True)
+ return softmax1
class ImperativeLenet(fluid.dygraph.Layer):
@@ -112,24 +113,24 @@ class ImperativeLenet(fluid.dygraph.Layer):
fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
self.features = Sequential(
Conv2D(
- num_channels=1,
+ in_channels=1,
- num_filters=6,
+ out_channels=6,
- filter_size=3,
+ kernel_size=3,
stride=1,
padding=1,
- param_attr=conv2d_w1_attr,
+ weight_attr=conv2d_w1_attr,
bias_attr=conv2d_b1_attr),
BatchNorm(6),
ReLU(),
Pool2D(
pool_size=2, pool_type='max', pool_stride=2),
Conv2D(
- num_channels=6,
+ in_channels=6,
- num_filters=16,
+ out_channels=16,
- filter_size=5,
+ kernel_size=5,
stride=1,
padding=0,
- param_attr=conv2d_w2_attr,
+ weight_attr=conv2d_w2_attr,
bias_attr=conv2d_b2_attr),
BatchNorm(16),
ReLU6(),
@@ -138,23 +139,23 @@ class ImperativeLenet(fluid.dygraph.Layer):
self.fc = Sequential(
Linear(
- input_dim=400,
+ in_features=400,
- output_dim=120,
+ out_features=120,
- param_attr=fc_w1_attr,
+ weight_attr=fc_w1_attr,
bias_attr=fc_b1_attr),
LeakyReLU(),
Linear(
- input_dim=120,
+ in_features=120,
- output_dim=84,
+ out_features=84,
- param_attr=fc_w2_attr,
+ weight_attr=fc_w2_attr,
bias_attr=fc_b2_attr),
Sigmoid(),
Linear(
- input_dim=84,
- act=classifier_activation,
- output_dim=num_classes,
- param_attr=fc_w3_attr,
- bias_attr=fc_b3_attr))
+ in_features=84,
+ out_features=num_classes,
+ weight_attr=fc_w3_attr,
+ bias_attr=fc_b3_attr),
+ Softmax())
def forward(self, inputs):
x = self.features(inputs)
@@ -165,105 +166,6 @@ class ImperativeLenet(fluid.dygraph.Layer):
class TestImperativeOutSclae(unittest.TestCase):
def test_calc_out_scale_save(self):
imperative_out_scale = ImperativeCalcOutScale()
with fluid.dygraph.guard():
lenet = ImperativeLenet()
adam = AdamOptimizer(
learning_rate=0.001, parameter_list=lenet.parameters())
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=32, drop_last=True)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=32)
imperative_out_scale.calc_out_scale(lenet)
epoch_num = 1
for epoch in range(epoch_num):
lenet.train()
for batch_id, data in enumerate(train_reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
acc = fluid.layers.accuracy(out, label)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
lenet.clear_gradients()
if batch_id % 100 == 0:
_logger.info(
"Train | At epoch {} step {}: loss = {:}, acc= {:}".
format(epoch, batch_id,
avg_loss.numpy(), acc.numpy()))
lenet.eval()
for batch_id, data in enumerate(test_reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
acc_top1 = fluid.layers.accuracy(
input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(
input=out, label=label, k=5)
if batch_id % 100 == 0:
_logger.info(
"Test | At epoch {} step {}: acc1 = {:}, acc5 = {:}".
format(epoch, batch_id,
acc_top1.numpy(), acc_top5.numpy()))
# save weights
model_dict = lenet.state_dict()
fluid.save_dygraph(model_dict, "save_temp")
# test the correctness of `save_quantized_model`
data = next(test_reader())
test_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
test_img = fluid.dygraph.to_variable(test_data)
lenet.eval()
before_save = lenet(test_img)
# save inference quantized model
path = "./outscale_infer_model/lenet"
save_dir = "./outscale_infer_model"
imperative_out_scale.save_quantized_model(
layer=lenet,
path=path,
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
])
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
[inference_program, feed_target_names, fetch_targets] = (
fluid.io.load_inference_model(
dirname=save_dir,
executor=exe,
model_filename="lenet" + INFER_MODEL_SUFFIX,
params_filename="lenet" + INFER_PARAMS_SUFFIX))
after_save, = exe.run(inference_program,
feed={feed_target_names[0]: test_data},
fetch_list=fetch_targets)
self.assertTrue(
np.allclose(after_save, before_save.numpy()),
msg='Failed to save the inference quantized model.')
def test_out_scale_acc(self):
def _build_static_lenet(main, startup, is_test=False, seed=1000):
with fluid.unique_name.guard():
@@ -285,6 +187,8 @@ class TestImperativeOutSclae(unittest.TestCase):
reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=32, drop_last=True)
+ weight_quantize_type = 'abs_max'
+ activation_quant_type = 'moving_average_abs_max'
param_init_map = {}
seed = 1000
lr = 0.1
@@ -295,7 +199,7 @@ class TestImperativeOutSclae(unittest.TestCase):
_logger.info(
"--------------------------dynamic graph qat--------------------------"
)
- imperative_out_scale = ImperativeCalcOutScale()
+ imperative_out_scale = ImperativeQuantAware()
with fluid.dygraph.guard():
np.random.seed(seed)
@@ -315,7 +219,7 @@ class TestImperativeOutSclae(unittest.TestCase):
fixed_state[name] = value
param_init_map[param.name] = value
lenet.set_dict(fixed_state)
- imperative_out_scale.calc_out_scale(lenet)
+ imperative_out_scale.quantize(lenet)
adam = AdamOptimizer(
learning_rate=lr, parameter_list=lenet.parameters())
dynamic_loss_rec = []
@@ -340,11 +244,9 @@ class TestImperativeOutSclae(unittest.TestCase):
_logger.info('{}: {}'.format('loss', avg_loss.numpy()))
lenet.eval()
- op_object_list = (Conv2D, ReLU, ReLU6, LeakyReLU, Sigmoid, Pool2D,
-     BatchNorm)
path = "./dynamic_outscale_infer_model/lenet"
- save_dir = "./dynamic_outscale_infer_model"
+ dynamic_save_dir = "./dynamic_outscale_infer_model"
imperative_out_scale.save_quantized_model(
layer=lenet,
@@ -384,8 +286,16 @@ class TestImperativeOutSclae(unittest.TestCase):
param_tensor.set(param_init_map[param.name], place)
main_graph = IrGraph(core.Graph(main.desc), for_test=False)
infer_graph = IrGraph(core.Graph(infer.desc), for_test=True)
- transform_pass = OutScaleForTrainingPass(scope=scope, place=place)
+ transform_pass = QuantizationTransformPass(
+     scope=scope,
+     place=place,
+     activation_quantize_type=activation_quant_type,
+     weight_quantize_type=weight_quantize_type,
+     quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'])
transform_pass.apply(main_graph)
+ transform_pass.apply(infer_graph)
+ outscale_pass = OutScaleForTrainingPass(scope=scope, place=place)
+ outscale_pass.apply(main_graph)
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_reduce_ops = False
binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
@@ -404,20 +314,18 @@
scale_inference_pass = OutScaleForInferencePass(scope=scope)
scale_inference_pass.apply(infer_graph)
- out_scale_op_list = [
-     "batch_norm", "conv2d", "leaky_relu", "pool2d", "relu6", "relu",
-     "sigmoid", "tanh", "relu6", "softmax", "conv2d_transpose",
-     "elementwise_add"
- ]
- op_nodes = infer_graph.all_op_nodes()
- for op_node in op_nodes:
-     if op_node.name() in out_scale_op_list:
-         static_out_scale_list.append(op_node.op().attr("out_threshold"))
save_program = infer_graph.to_program()
+ static_save_dir = "./static_outscale_infer_model"
with fluid.scope_guard(scope):
- fluid.io.save_inference_model("./static_mnist", [infer_img.name],
-     [infer_pre], exe, save_program)
+ fluid.io.save_inference_model(
+     dirname=static_save_dir,
+     feeded_var_names=[infer_img.name],
+     target_vars=[infer_pre],
+     executor=exe,
+     main_program=save_program,
+     model_filename="lenet" + INFER_MODEL_SUFFIX,
+     params_filename="lenet" + INFER_PARAMS_SUFFIX)
rtol = 1e-05
atol = 1e-08
for i, (loss_d,
@@ -437,24 +345,38 @@ class TestImperativeOutSclae(unittest.TestCase):
atol=atol,
equal_nan=True),
msg='Failed to do the imperative qat.')
# load dynamic model
- [inference_program, feed_target_names, fetch_targets] = (
+ [dynamic_inference_program, feed_target_names, fetch_targets] = (
fluid.io.load_inference_model(
- dirname=save_dir,
+ dirname=dynamic_save_dir,
executor=exe,
model_filename="lenet" + INFER_MODEL_SUFFIX,
params_filename="lenet" + INFER_PARAMS_SUFFIX))
+ # load static model
+ [static_inference_program, feed_target_names, fetch_targets] = (
+     fluid.io.load_inference_model(
+         dirname=static_save_dir,
+         executor=exe,
+         model_filename="lenet" + INFER_MODEL_SUFFIX,
+         params_filename="lenet" + INFER_PARAMS_SUFFIX))
+ dynamic_ops = dynamic_inference_program.global_block().ops
+ static_ops = static_inference_program.global_block().ops
+ for op in dynamic_ops[:]:
+     if op.type == "flatten2" or 'fake' in op.type:
+         dynamic_ops.remove(op)
- global_block = inference_program.global_block()
- for op in global_block.ops:
-     if op.has_attr('out_threshold'):
-         dynamic_out_scale_list.append(op.attr('out_threshold'))
- check_list = [
-     False for item in dynamic_out_scale_list
-     if item not in static_out_scale_list
- ]
- self.assertTrue(len(check_list) == 0)
+ for op in static_ops[:]:
+     if 'fake' in op.type:
+         static_ops.remove(op)
+ for i in range(len(dynamic_ops)):
+     if dynamic_ops[i].has_attr("out_threshold"):
+         self.assertTrue(dynamic_ops[i].type == static_ops[i].type)
+         self.assertTrue(dynamic_ops[i].attr("out_threshold") ==
+                         static_ops[i].attr("out_threshold"))
if __name__ == '__main__':
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import numpy as np
import random
import unittest
import logging
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid import core
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn.layer import ReLU, LeakyReLU, Sigmoid, Softmax, ReLU6
from paddle.nn import Linear, Conv2D, Softmax, BatchNorm
from paddle.fluid.dygraph.nn import Pool2D
from paddle.fluid.log_helper import get_logger
os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
quant_skip_pattern_list = ['skip_qat', 'skip_quant']
class ImperativeLenet(fluid.dygraph.Layer):
def __init__(self, num_classes=10, classifier_activation='softmax'):
super(ImperativeLenet, self).__init__()
conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
fc_w2_attr = fluid.ParamAttr(name="fc_w_2")
fc_w3_attr = fluid.ParamAttr(name="fc_w_3")
conv2d_b1_attr = fluid.ParamAttr(name="conv2d_b_1")
conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2")
fc_b1_attr = fluid.ParamAttr(name="fc_b_1")
fc_b2_attr = fluid.ParamAttr(name="fc_b_2")
fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
self.conv2d_0 = Conv2D(
in_channels=1,
out_channels=6,
kernel_size=3,
stride=1,
padding=1,
weight_attr=conv2d_w1_attr,
bias_attr=conv2d_b1_attr)
self.conv2d_0.skip_quant = True
self.batch_norm_0 = BatchNorm(6)
self.relu_0 = ReLU()
self.pool2d_0 = Pool2D(pool_size=2, pool_type='max', pool_stride=2)
self.conv2d_1 = Conv2D(
in_channels=6,
out_channels=16,
kernel_size=5,
stride=1,
padding=0,
weight_attr=conv2d_w2_attr,
bias_attr=conv2d_b2_attr)
self.conv2d_1.skip_quant = False
self.batch_norm_1 = BatchNorm(16)
self.relu6_0 = ReLU6()
self.pool2d_1 = Pool2D(pool_size=2, pool_type='max', pool_stride=2)
self.linear_0 = Linear(
in_features=400,
out_features=120,
weight_attr=fc_w1_attr,
bias_attr=fc_b1_attr)
self.linear_0.skip_quant = True
self.leaky_relu_0 = LeakyReLU()
self.linear_1 = Linear(
in_features=120,
out_features=84,
weight_attr=fc_w2_attr,
bias_attr=fc_b2_attr)
self.linear_1.skip_quant = False
self.sigmoid_0 = Sigmoid()
self.linear_2 = Linear(
in_features=84,
out_features=num_classes,
weight_attr=fc_w3_attr,
bias_attr=fc_b3_attr)
self.linear_2.skip_quant = False
self.softmax_0 = Softmax()
def forward(self, inputs):
x = self.conv2d_0(inputs)
x = self.batch_norm_0(x)
x = self.relu_0(x)
x = self.pool2d_0(x)
x = self.conv2d_1(x)
x = self.batch_norm_1(x)
x = self.relu6_0(x)
x = self.pool2d_1(x)
x = fluid.layers.flatten(x, 1)
x = self.linear_0(x)
x = self.leaky_relu_0(x)
x = self.linear_1(x)
x = self.sigmoid_0(x)
x = self.linear_2(x)
x = self.softmax_0(x)
return x
class TestImperativeOutSclae(unittest.TestCase):
def test_out_scale_acc(self):
seed = 1000
lr = 0.1
imperative_out_scale = ImperativeQuantAware()
np.random.seed(seed)
reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=32, drop_last=True)
lenet = ImperativeLenet()
fixed_state = {}
for name, param in lenet.named_parameters():
p_shape = param.numpy().shape
p_value = param.numpy()
if name.endswith("bias"):
value = np.zeros_like(p_value).astype('float32')
else:
value = np.random.normal(
loc=0.0, scale=0.01,
size=np.product(p_shape)).reshape(p_shape).astype('float32')
fixed_state[name] = value
lenet.set_dict(fixed_state)
imperative_out_scale.quantize(lenet)
adam = AdamOptimizer(
learning_rate=lr, parameter_list=lenet.parameters())
dynamic_loss_rec = []
lenet.train()
for batch_id, data in enumerate(reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
lenet.clear_gradients()
dynamic_loss_rec.append(avg_loss.numpy()[0])
if batch_id % 100 == 0:
_logger.info('{}: {}'.format('loss', avg_loss.numpy()))
lenet.eval()
path = "./save_dynamic_quant_infer_model/lenet"
save_dir = "./save_dynamic_quant_infer_model"
imperative_out_scale.save_quantized_model(
layer=lenet,
path=path,
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
])
paddle.enable_static()
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
[inference_program, feed_target_names, fetch_targets] = (
fluid.io.load_inference_model(
dirname=save_dir,
executor=exe,
model_filename="lenet" + INFER_MODEL_SUFFIX,
params_filename="lenet" + INFER_PARAMS_SUFFIX))
model_ops = inference_program.global_block().ops
conv2d_count, mul_count = 0, 0
for i, op in enumerate(model_ops):
if op.type == 'conv2d':
if conv2d_count > 0:
self.assertTrue(
'fake_quantize_dequantize' in model_ops[i - 1].type)
else:
self.assertTrue(
'fake_quantize_dequantize' not in model_ops[i - 1].type)
conv2d_count += 1
if op.type == 'mul':
if mul_count > 0:
self.assertTrue(
'fake_quantize_dequantize' in model_ops[i - 1].type)
else:
self.assertTrue(
'fake_quantize_dequantize' not in model_ops[i - 1].type)
mul_count += 1
if __name__ == '__main__':
unittest.main()
@@ -73,7 +73,7 @@ class TestMovingAverageAbsMaxScaleOp(unittest.TestCase):
feed_dict = {"image": img, "label": label}
res = exe.run(binary, feed_dict)
- def test_fw_bw(self):
+ def test_check_op_times(self):
if core.is_compiled_with_cuda():
self.check_backward(use_cuda=True)
self.check_backward(use_cuda=False)
......