Unverified commit ef0dd3ef, authored by guofei, committed by GitHub

Support loading parameters from checkpoint to save quantized model (#31419)

* Support loading parameters from checkpoint to save quantized model

* Fix the unittest test_moving_average_abs_max_scale_op

* Add unittest of save_quantized_model from checkpoint

* Add comments to explain the function
Parent da9dda5c
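
The new unit test below exercises the following workflow: run quantization-aware training, save an ordinary checkpoint with paddle.save, then rebuild the quantized network in a later process, load the checkpoint, and export the inference model. A minimal sketch of that flow (MyLenet and the paths are placeholders, not part of this change; the API calls follow the new test):

    import paddle
    from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

    class MyLenet(paddle.nn.Layer):
        # Stand-in network; any dygraph model is handled the same way.
        def __init__(self):
            super(MyLenet, self).__init__()
            self.fc = paddle.nn.Linear(784, 10)

        def forward(self, x):
            return self.fc(paddle.flatten(x, 1))

    quant_config = dict(
        weight_quantize_type='abs_max',
        activation_quantize_type='moving_average_abs_max')

    # Training process: insert fake-quant layers, train, save a checkpoint.
    model = MyLenet()
    qat = ImperativeQuantAware(**quant_config)
    qat.quantize(model)
    # ... the quantization-aware training loop would run here ...
    paddle.save(model.state_dict(), "checkpoint/lenet.pdparams")

    # A later process: rebuild the quantized model, load the checkpoint,
    # and export the inference model. The out_threshold attributes are now
    # filled from the pre-created _quant_out_scale parameters instead of
    # relying on the forward hooks having run in this process.
    model = MyLenet()
    qat = ImperativeQuantAware(**quant_config)
    qat.quantize(model)
    model.set_dict(paddle.load("checkpoint/lenet.pdparams"))
    qat.save_quantized_model(
        layer=model,
        path="inference/lenet",
        input_spec=[
            paddle.static.InputSpec(
                shape=[None, 1, 28, 28], dtype='float32')
        ])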
@@ -17,11 +17,15 @@ import logging
 import numpy as np
 import sys
 import os
+import warnings
 import paddle
-from paddle.fluid import dygraph, core, framework
+from paddle.fluid import dygraph, core, framework, unique_name
 from paddle.fluid.executor import Executor
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.initializer import Constant
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
-from paddle.nn import Linear, Conv2D, Conv2DTranspose, MaxPool2D, MaxPool1D, BatchNorm1D, BatchNorm2D, BatchNorm3D
+from paddle.nn import Linear, Conv2D, Conv2DTranspose, MaxPool2D, MaxPool1D, BatchNorm1D, BatchNorm2D, BatchNorm3D, SyncBatchNorm
 from paddle.fluid.dygraph.nn import BatchNorm, Pool2D
 from paddle.fluid.io import load_inference_model, save_inference_model
 from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6, Tanh, Softmax, PReLU, Swish
@@ -331,10 +335,73 @@ class ImperativeCalcOutScale(object):
         self._out_scale_layer_type_list = (
             BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D, LeakyReLU,
             Linear, PReLU, Pool2D, MaxPool1D, MaxPool2D, ReLU, ReLU6, Sigmoid,
-            Softmax, Tanh, Swish)
+            Softmax, SyncBatchNorm, Tanh, Swish)
         self._register_hook_handle_list = []
         self._out_scale_dict = collections.OrderedDict()

+    # Determine whether the layer supports calculating out_scale.
+    def _is_matched_layer(self, layer):
+        if not isinstance(layer, self._out_scale_layer_type_list):
+            if 'quantized_' not in layer.full_name():
+                return False
+        return True
+
+    # When the inference model is saved, the logic in the hook would not be
+    # executed during program translation, so some parameters could not be
+    # created in __init__, which would cause the model to fail to save.
+    # Therefore, the parameter creation in the hook is moved ahead and
+    # executed outside the hook.
+    def _add_new_parameters(self, layer, name=None):
+        dtype = layer._dtype if layer._dtype is not None else "float32"
+        if dtype not in ["float32", "float64"]:
+            return
+        scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
+        scale_name = unique_name.generate(scale_prefix)
+        scale_attr = ParamAttr(
+            name=scale_name, initializer=Constant(1), trainable=False)
+        layer._quant_out_scale = layer.create_parameter(
+            shape=[1], attr=scale_attr, dtype=dtype)
+        layer._quant_out_scale.stop_gradient = True
+
+        state_prefix = "{}.state".format(name) if name else 'outscale.state'
+        state_attr = ParamAttr(
+            name=unique_name.generate(state_prefix),
+            initializer=Constant(1),
+            trainable=False)
+        layer._quant_out_state = layer.create_parameter(
+            shape=[1], attr=state_attr, dtype=dtype)
+        layer._quant_out_state.stop_gradient = True
+
+        accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
+        accum_attr = ParamAttr(
+            name=unique_name.generate(accum_prefix),
+            initializer=Constant(1),
+            trainable=False)
+        layer._quant_out_accum = layer.create_parameter(
+            shape=[1], attr=accum_attr, dtype=dtype)
+        layer._quant_out_accum.stop_gradient = True
+
+    # Judge whether an op in the program matches a Layer in the dynamic model.
+    def _is_op_matched(self, layer_name, op, block):
+        output_var_names = quantization_pass._get_op_output_var_names(op)
+        for output_var_name in output_var_names:
+            output_var_tensor = block.var(output_var_name)
+            if output_var_tensor.dtype not in [
+                    core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32
+            ]:
+                return False
+
+        # Because the naming styles of the static and dynamic graph differ,
+        # unify the name here to avoid mismatches.
+        op_type = output_var_names[0].split(".")[0]
+        op_type = op_type.rsplit("_", 1)[0]
+        if op_type == 'depthwise_conv2d':
+            op_type = 'conv2d'
+        if 'prelu' in op_type:
+            op_type = op_type.replace('prelu', 'p_re_lu')
+        if 'relu' in op_type:
+            op_type = op_type.replace('relu', 're_lu')
+        return op_type in layer_name
+
     def calc_out_scale(self, model):
         """
         Insert the `moving_average_abs_max_scale` op to calculate output scale of Specific layers in model.
@@ -348,9 +415,8 @@ class ImperativeCalcOutScale(object):
         assert isinstance(
             model, dygraph.Layer), "model must be the instance of dygraph.Layer"
         for _, layer in model.named_sublayers():
-            if not isinstance(layer, self._out_scale_layer_type_list):
-                if 'quantized_' not in layer.full_name():
-                    continue
-            forward_post_hook_handle = layer.register_forward_post_hook(
-                self._forward_post_hook)
-            self._register_hook_handle_list.append(forward_post_hook_handle)
+            if self._is_matched_layer(layer):
+                self._add_new_parameters(layer)
+                forward_post_hook_handle = layer.register_forward_post_hook(
+                    self._forward_post_hook)
+                self._register_hook_handle_list.append(forward_post_hook_handle)
@@ -380,14 +446,26 @@ class ImperativeCalcOutScale(object):
         assert isinstance(
             layer, dygraph.Layer), "model must be the instance of dygraph.Layer"

+        self._layer = layer
         is_dynamic_mode = False
         with dygraph.guard():
-            layer.eval()
-            for handle in self._register_hook_handle_list:
-                handle.remove()
-            for key in self._out_scale_dict:
-                self._out_scale_dict[key] = float(self._out_scale_dict[key]
-                                                  .numpy())
+            self._layer.eval()
+            if self._register_hook_handle_list is not None:
+                for handle in self._register_hook_handle_list:
+                    handle.remove()
+
+            if self._out_scale_dict:
+                for key in self._out_scale_dict:
+                    self._out_scale_dict[key] = float(self._out_scale_dict[key]
+                                                      .numpy())
+            else:
+                for _, sub_layer in self._layer.named_sublayers():
+                    if self._is_matched_layer(sub_layer):
+                        layer_name = sub_layer.full_name()
+                        if hasattr(sub_layer, "layer_name"):
+                            layer_name = sub_layer.layer_name
+                        if hasattr(sub_layer, "_quant_out_scale"):
+                            self._out_scale_dict[layer_name] = float(
+                                sub_layer._quant_out_scale)

         if paddle.in_dynamic_mode():
             is_dynamic_mode = True
@@ -413,74 +491,68 @@ class ImperativeCalcOutScale(object):
                 model_filename=model_filename,
                 params_filename=params_filename))

-        # Traverse all ops in the program and find out the op matching
-        # the Layer in the dynamic graph.
-        layer_var_dict = collections.OrderedDict()
-        ops_list = [key for key, _ in self._out_scale_dict.items()]
-        op_count = 0
-        conv_count = 0
-
-        for block in inference_program.blocks:
-            for op in block.ops:
-                if op.type in _op_real_in_out_name:
-                    if op.type in ["batch_norm", "pool2d"]:
-                        if op.type == "pool2d" and op.attr(
-                                "pooling_type") != "max":
-                            continue
-                        op_count = self.op_match(op, ops_list, op_count)
-                        if op_count >= len(ops_list):
-                            continue
-                        op._set_attr('out_threshold',
-                                     self._out_scale_dict[ops_list[op_count]])
-                        op_count += 1
-                    else:
-                        output_var_names = quantization_pass._get_op_output_var_names(
-                            op)
-                        for output_var_name in output_var_names:
-                            output_var_tensor = block.var(output_var_name)
-                            if output_var_tensor.dtype not in [
-                                    core.VarDesc.VarType.FP64,
-                                    core.VarDesc.VarType.FP32
-                            ]:
-                                continue
-                            # Because the Layer in dygraph may correspond to multiple ops
-                            # in static program after being saved. To ensure correctness,
-                            # the outscale collected for output of dygraph Layer can only
-                            # be set to the last op in the corresponding ops in static program.
-                            #
-                            # We can judge the execution order of the ops which corresponding
-                            # to dygraph Layer by the name of output. And use dict to save
-                            # the corresponding relationship between the dygraph Layer and the
-                            # static graph op that needs to set the outscale attribute.
-                            if '.' not in output_var_name:
-                                continue
-                            dynamic_layer_name, var_name_suffix = output_var_name.split(
-                                ".")
-                            if dynamic_layer_name in layer_var_dict:
-                                if layer_var_dict[dynamic_layer_name][
-                                        0] < var_name_suffix:
-                                    layer_var_dict[dynamic_layer_name] = [
-                                        var_name_suffix, op
-                                    ]
-                            else:
-                                layer_var_dict[dynamic_layer_name] = [
-                                    var_name_suffix, op
-                                ]
-
-        # Because the naming styles of static and dynamic graph are different,
-        # in order to avoid mistakes, we unify the name here.
-        for (layer_name, var_name_op_list) in layer_var_dict.items():
-            if 'prelu' in layer_name:
-                layer_name = layer_name.replace('prelu', 'p_re_lu')
-            if 'relu' in layer_name:
-                layer_name = layer_name.replace('relu', 're_lu')
-            if 'conv2d' in layer_name:
-                layer_name = 'conv2d_' + str(conv_count)
-                conv_count = conv_count + 1
-            if layer_name not in self._out_scale_dict:
-                continue
-            var_name_op_list[1]._set_attr('out_threshold',
-                                          self._out_scale_dict[layer_name])
+        check_behind_op = False
+        op_count = 0
+        ops_list = [key for key, _ in self._out_scale_dict.items()]
+        if len(ops_list) == 0:
+            warnings.warn(
+                "Warning: No Layer of the model while to be saved contains the out_threshold attribute, "
+                "so the generated inference model would not contain the out_threshold."
+            )
+        else:
+            # Because a Layer in dygraph may correspond to multiple ops in the
+            # static program after being saved, the outscale collected for the
+            # output of a dygraph Layer can only be set on the last of the
+            # corresponding ops in the static program.
+            #
+            # The execution order of the ops corresponding to a dygraph Layer
+            # is tracked with check_behind_op.
+            forward_op = None
+            for block in inference_program.blocks:
+                for op in block.ops:
+                    if op.type in _op_real_in_out_name:
+                        if op_count > len(ops_list):
+                            warnings.warn(
+                                "The number of Layer which has out_threshold attribute should be bigger than the op in inference model"
+                            )
+                            break
+                        if check_behind_op:
+                            check_behind_op = False
+                            if op.type == "elementwise_add":
+                                if self._is_op_matched(ops_list[op_count], op,
+                                                       block):
+                                    op._set_attr("out_threshold",
+                                                 self._out_scale_dict[ops_list[
+                                                     op_count]])
+                                    op_count += 1
+                                    forward_op = None
+                                continue
+                            else:
+                                if forward_op is None:
+                                    raise ValueError(
+                                        "forward_op should not be None")
+                                if self._is_op_matched(ops_list[op_count],
+                                                       forward_op, block):
+                                    forward_op._set_attr(
+                                        "out_threshold", self._out_scale_dict[
+                                            ops_list[op_count]])
+                                    op_count += 1
+                                forward_op = None
+
+                        if op.type in ["conv2d", "depthwise_conv2d", "matmul"]:
+                            check_behind_op = True
+                            forward_op = op
+                            continue
+                        if op_count >= len(ops_list):
+                            warnings.warn(
+                                "The number of Layer which has out_threshold attribute should be bigger than the op in inference model"
+                            )
+                            break
+                        if self._is_op_matched(ops_list[op_count], op, block):
+                            op._set_attr(
+                                "out_threshold",
+                                self._out_scale_dict[ops_list[op_count]])
+                            op_count += 1

         # Save the processed program.
         save_inference_model(
@@ -495,14 +567,6 @@ class ImperativeCalcOutScale(object):
         if is_dynamic_mode:
             paddle.disable_static()

-    def op_match(self, op, ops_list, op_count):
-        while op_count < len(ops_list) and op.type not in ops_list[op_count]:
-            op_count += 1
-        while op_count < len(ops_list) and op.type is "pool2d" and op.attr(
-                "pooling_type") != "max":
-            op_count += 1
-        return op_count
-
     def _forward_post_hook(self, layer, input, output):
         assert isinstance(
             output, (core.VarBase, framework.Variable)
@@ -512,9 +576,9 @@ class ImperativeCalcOutScale(object):
         ]:
             return
         if not hasattr(layer, "_out_scale"):
-            layer._out_scale = quant_nn.MovingAverageAbsMaxScale(
-                output.name, self._moving_rate, output.dtype)
-        scale_out = layer._out_scale(output)
+            self._out_scale = quant_nn.MovingAverageAbsMaxScale(
+                layer, output.name, self._moving_rate, output.dtype)
+        scale_out = self._out_scale(output)
         if hasattr(layer, 'layer_name'):
             layer_name = layer.layer_name
         else:
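
The op/Layer matching added above in _is_op_matched depends on normalizing a static-graph output variable name into the dygraph layer naming style before comparing it with Layer.full_name(). A standalone illustration of that normalization (the helper name is invented for this sketch):

    def normalize_static_op_name(output_var_name):
        # Static-program names look like "relu_3.tmp_0": strip the variable
        # suffix and the trailing index, then map the static op type to the
        # dygraph layer naming style.
        op_type = output_var_name.split(".")[0]    # "relu_3.tmp_0" -> "relu_3"
        op_type = op_type.rsplit("_", 1)[0]        # "relu_3"       -> "relu"
        if op_type == 'depthwise_conv2d':
            op_type = 'conv2d'
        if 'prelu' in op_type:
            op_type = op_type.replace('prelu', 'p_re_lu')
        if 'relu' in op_type:
            op_type = op_type.replace('relu', 're_lu')
        return op_type

    assert normalize_static_op_name("relu_3.tmp_0") == "re_lu"        # matches layer "re_lu_3"
    assert normalize_static_op_name("prelu_0.tmp_0") == "p_re_lu"     # matches layer "p_re_lu_0"
    assert normalize_static_op_name("depthwise_conv2d_1.tmp_0") == "conv2d"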
@@ -503,7 +503,7 @@ class QuantizedNoweightLayer(layers.Layer):

 class MovingAverageAbsMaxScale(layers.Layer):
-    def __init__(self, name=None, moving_rate=0.9, dtype='float32'):
+    def __init__(self, layer=None, name=None, moving_rate=0.9, dtype='float32'):
         r"""
         MovingAverageMaxScale layer is used to calculating the output quantization scale of Layer.
         Its computational formula is described as below:
@@ -514,15 +514,22 @@ class MovingAverageAbsMaxScale(layers.Layer):
         super(MovingAverageAbsMaxScale, self).__init__()
         self._moving_rate = moving_rate
         self._dtype = dtype
+        self._layer = layer

-        scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
-        name = unique_name.generate(scale_prefix)
-        scale_attr = ParamAttr(
-            name=name, initializer=Constant(1), trainable=False)
-        self._scale = self.create_parameter(
-            shape=[1], attr=scale_attr, dtype=self._dtype)
-        self._scale.stop_gradient = True
+        if self._layer is None or not hasattr(self._layer, "_quant_out_scale"):
+            scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
+            scale_name = unique_name.generate(scale_prefix)
+            scale_attr = ParamAttr(
+                name=scale_name, initializer=Constant(1), trainable=False)
+            self._scale = self.create_parameter(
+                shape=[1], attr=scale_attr, dtype=self._dtype)
+            self._scale.stop_gradient = True
+            if self._layer is not None:
+                setattr(self._layer, "_quant_out_scale", self._scale)
+        else:
+            self._scale = self._layer._quant_out_scale

-        state_prefix = "{}.state".format(name) if name else 'outscale.state'
-        state_attr = ParamAttr(
-            name=unique_name.generate(state_prefix),
+        if self._layer is None or not hasattr(self._layer, "_quant_out_state"):
+            state_prefix = "{}.state".format(name) if name else 'outscale.state'
+            state_attr = ParamAttr(
+                name=unique_name.generate(state_prefix),
@@ -531,7 +538,12 @@ class MovingAverageAbsMaxScale(layers.Layer):
-        self._state = self.create_parameter(
-            shape=[1], attr=state_attr, dtype=self._dtype)
-        self._state.stop_gradient = True
+            self._state = self.create_parameter(
+                shape=[1], attr=state_attr, dtype=self._dtype)
+            self._state.stop_gradient = True
+            if self._layer is not None:
+                setattr(self._layer, "_quant_out_state", self._state)
+        else:
+            self._state = self._layer._quant_out_state

-        accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
-        accum_attr = ParamAttr(
-            name=unique_name.generate(accum_prefix),
+        if self._layer is None or not hasattr(self._layer, "_quant_out_accum"):
+            accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
+            accum_attr = ParamAttr(
+                name=unique_name.generate(accum_prefix),
@@ -540,7 +552,10 @@ class MovingAverageAbsMaxScale(layers.Layer):
-        self._accum = self.create_parameter(
-            shape=[1], attr=accum_attr, dtype=self._dtype)
-        self._accum.stop_gradient = True
-        MovingAverageAbsMaxScale._has_create = True
+            self._accum = self.create_parameter(
+                shape=[1], attr=accum_attr, dtype=self._dtype)
+            self._accum.stop_gradient = True
+            if self._layer is not None:
+                setattr(self._layer, "_quant_out_accum", self._accum)
+        else:
+            self._accum = self._layer._quant_out_accum

     def forward(self, input):
         if in_dygraph_mode():
@@ -549,18 +564,17 @@ class MovingAverageAbsMaxScale(layers.Layer):
             state = self._state if self.training else None
             accum = self._accum if self.training else None

-            out_scale, _, _ = core.ops.moving_average_abs_max_scale(
+            self._scale, _, _ = core.ops.moving_average_abs_max_scale(
                 input, accum, state, self._scale, state, accum, *attrs)
-            return out_scale
+            return self._scale

         check_variable_and_dtype(input, 'input', ['float32', 'float64'],
                                  'MovingAverageAbsMaxScale')
-        scale_out = self._scale
         attrs = {'moving_rate': self._moving_rate, 'is_test': not self.training}
         inputs = {"X": [input]}
-        outputs = {"OutScale": [scale_out]}
+        outputs = {"OutScale": [self._scale]}

         if self.training:
             inputs['InState'] = [self._state]
@@ -574,4 +588,4 @@ class MovingAverageAbsMaxScale(layers.Layer):
             outputs=outputs,
             attrs=attrs)

-        return scale_out
+        return self._scale
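
Saving directly from a checkpoint works because the scale/state/accum parameters are pre-created on each matched layer by _add_new_parameters and then reused by MovingAverageAbsMaxScale, instead of being created lazily inside the forward hook. A reduced sketch of this create-or-reuse pattern (HostLayer and ScaleObserver are illustrative stand-ins, not the library classes):

    import paddle
    from paddle.fluid.param_attr import ParamAttr
    from paddle.fluid.initializer import Constant

    class HostLayer(paddle.nn.Layer):
        # Plays the role of a layer that _add_new_parameters has visited: the
        # scale parameter exists before any forward pass or program translation.
        def __init__(self):
            super(HostLayer, self).__init__()
            scale_attr = ParamAttr(initializer=Constant(1), trainable=False)
            self._quant_out_scale = self.create_parameter(
                shape=[1], attr=scale_attr, dtype='float32')
            self._quant_out_scale.stop_gradient = True

    class ScaleObserver(paddle.nn.Layer):
        # Plays the role of MovingAverageAbsMaxScale: reuse the host layer's
        # parameter when present, otherwise fall back to creating a new one.
        def __init__(self, layer=None):
            super(ScaleObserver, self).__init__()
            if layer is not None and hasattr(layer, "_quant_out_scale"):
                self._scale = layer._quant_out_scale
            else:
                self._scale = self.create_parameter(
                    shape=[1],
                    attr=ParamAttr(initializer=Constant(1), trainable=False),
                    dtype='float32')
                self._scale.stop_gradient = True

    host = HostLayer()
    observer = ScaleObserver(layer=host)
    # Both objects refer to the same parameter, so values loaded into the
    # checkpointed layer are visible to the observer as well.
    assert observer._scale is host._quant_out_scale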
@@ -19,6 +19,8 @@ import numpy as np
 import random
 import unittest
 import logging
+import warnings
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
@@ -29,7 +31,7 @@ from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
 from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass, OutScaleForInferencePass, QuantizationTransformPass
 from paddle.fluid.dygraph.container import Sequential
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
-from paddle.nn.layer import ReLU, LeakyReLU, Sigmoid, Softmax, ReLU6
+from paddle.nn.layer import ReLU, LeakyReLU, Sigmoid, Softmax, PReLU
 from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D
 from paddle.fluid.dygraph.nn import Pool2D
 from paddle.fluid.log_helper import get_logger
@@ -45,6 +47,14 @@ _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')


+def get_vaild_warning_num(warning, w):
+    num = 0
+    for i in range(len(w)):
+        if warning in str(w[i].message):
+            num += 1
+    return num
+
+
 def StaticLenet(data, num_classes=10, classifier_activation='softmax'):
     conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
     conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
@@ -76,9 +86,9 @@ def StaticLenet(data, num_classes=10, classifier_activation='softmax'):
         param_attr=conv2d_w2_attr,
         bias_attr=conv2d_b2_attr)
     batch_norm2 = layers.batch_norm(conv2)
-    relu6_1 = layers.relu6(batch_norm2)
+    prelu1 = layers.prelu(batch_norm2, mode='all')
     pool2 = fluid.layers.pool2d(
-        relu6_1, pool_size=2, pool_type='max', pool_stride=2)
+        prelu1, pool_size=2, pool_type='max', pool_stride=2)
     fc1 = fluid.layers.fc(input=pool2,
                           size=120,
@@ -132,7 +142,7 @@ class ImperativeLenet(fluid.dygraph.Layer):
                 weight_attr=conv2d_w2_attr,
                 bias_attr=conv2d_b2_attr),
             BatchNorm2D(16),
-            ReLU6(),
+            PReLU(),
             MaxPool2D(
                 kernel_size=2, stride=2))
@@ -246,6 +256,10 @@ class TestImperativeOutSclae(unittest.TestCase):
         lenet.eval()

+        param_save_path = "test_save_quantized_model/lenet.pdparams"
+        save_dict = lenet.state_dict()
+        paddle.save(save_dict, param_save_path)
+
         path = "./dynamic_outscale_infer_model/lenet"
         dynamic_save_dir = "./dynamic_outscale_infer_model"
@@ -285,6 +299,8 @@ class TestImperativeOutSclae(unittest.TestCase):
         for param in main.all_parameters():
             if "batch_norm" in param.name:
                 param_name = param.name.replace("norm", "norm2d")
+            elif 'prelu' in param.name:
+                param_name = param.name.replace("prelu", 'p_re_lu')
             else:
                 param_name = param.name
             param_tensor = scope.var(param.name).get_tensor()
@@ -384,5 +400,94 @@ class TestImperativeOutSclae(unittest.TestCase):
                                 static_ops[i].attr("out_threshold"))


+class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase):
+    def test_save_quantized_model(self):
+        weight_quantize_type = 'abs_max'
+        activation_quantize_type = 'moving_average_abs_max'
+        load_param_path = "test_save_quantized_model/lenet.pdparams"
+        path = "./dynamic_outscale_infer_model_from_checkpoint/lenet"
+        dynamic_model_save_dir = "./dynamic_outscale_infer_model_from_checkpoint"
+        static_model_save_dir = "./static_outscale_infer_model"
+
+        imperative_out_scale = ImperativeQuantAware(
+            weight_quantize_type=weight_quantize_type,
+            activation_quantize_type=activation_quantize_type)
+
+        with fluid.dygraph.guard():
+            lenet = ImperativeLenet()
+            load_dict = paddle.load(load_param_path)
+            imperative_out_scale.quantize(lenet)
+            lenet.set_dict(load_dict)
+
+            imperative_out_scale.save_quantized_model(
+                layer=lenet,
+                path=path,
+                input_spec=[
+                    paddle.static.InputSpec(
+                        shape=[None, 1, 28, 28], dtype='float32')
+                ])
+
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+        else:
+            place = core.CPUPlace()
+        exe = fluid.Executor(place)
+
+        # load dynamic model
+        [dynamic_inference_program, feed_target_names, fetch_targets] = (
+            fluid.io.load_inference_model(
+                dirname=dynamic_model_save_dir,
+                executor=exe,
+                model_filename="lenet" + INFER_MODEL_SUFFIX,
+                params_filename="lenet" + INFER_PARAMS_SUFFIX))
+        # load static model
+        [static_inference_program, feed_target_names, fetch_targets] = (
+            fluid.io.load_inference_model(
+                dirname=static_model_save_dir,
+                executor=exe,
+                model_filename="lenet" + INFER_MODEL_SUFFIX,
+                params_filename="lenet" + INFER_PARAMS_SUFFIX))
+
+        dynamic_ops = dynamic_inference_program.global_block().ops
+        static_ops = static_inference_program.global_block().ops
+
+        for op in dynamic_ops[:]:
+            if op.type == "flatten2" or 'fake' in op.type:
+                dynamic_ops.remove(op)
+
+        for op in static_ops[:]:
+            if 'fake' in op.type:
+                static_ops.remove(op)
+
+        for i in range(len(dynamic_ops)):
+            if dynamic_ops[i].has_attr("out_threshold"):
+                self.assertTrue(dynamic_ops[i].type == static_ops[i].type)
+                self.assertTrue(dynamic_ops[i].attr("out_threshold") ==
+                                static_ops[i].attr("out_threshold"))
+
+
+class TestSaveQuantizedModel_Warning(unittest.TestCase):
+    def test_warning(self):
+        path = "./dynamic_outscale_infer_model_with_warnings/lenet"
+        imperative_out_scale = ImperativeQuantAware()
+        with fluid.dygraph.guard():
+            lenet = ImperativeLenet()
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            imperative_out_scale.save_quantized_model(
+                layer=lenet,
+                path=path,
+                input_spec=[
+                    paddle.static.InputSpec(
+                        shape=[None, 1, 28, 28], dtype='float32')
+                ])
+
+        warning_message = "Warning: No Layer of the model while to be saved contains the out_threshold attribute, " \
+            "so the generated inference model would not contain the out_threshold."
+        num = get_vaild_warning_num(warning_message, w)
+        assert num == 1
+
+
 if __name__ == '__main__':
     unittest.main()
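
As in the tests above, the exported inference program can be inspected for the out_threshold attributes written by save_quantized_model; the directory and file prefix below follow TestSaveQuanztizedModelFromCheckPoint and stand for any exported model:

    import paddle
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

    paddle.enable_static()
    exe = fluid.Executor(fluid.CPUPlace())
    [program, feed_names, fetch_targets] = fluid.io.load_inference_model(
        dirname="./dynamic_outscale_infer_model_from_checkpoint",
        executor=exe,
        model_filename="lenet" + INFER_MODEL_SUFFIX,
        params_filename="lenet" + INFER_PARAMS_SUFFIX)

    # Every op that received an output scale carries the out_threshold attribute.
    for op in program.global_block().ops:
        if op.has_attr("out_threshold"):
            print(op.type, op.attr("out_threshold"))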