Unverified commit 374a5f7c authored by zhouzj, committed by GitHub

add an interface to predict the latency of compressed model (#1045)

* add an interface to predict the latency of compressed model

* just for rerun

* delete temporary model files
Co-authored-by: ceci3 <ceci3@users.noreply.github.com>
Parent 2ef31f9a
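A minimal usage sketch of the new interface; the import path and model file paths below are illustrative assumptions, not part of this commit:

# Illustrative import; adjust to the module where predict_compressed_model is defined.
from predict import predict_compressed_model

# Hypothetical exported Paddle inference model files.
latency_dict = predict_compressed_model(
    model_file='./inference_model/model.pdmodel',
    param_file='./inference_model/model.pdiparams',
    hardware='SD710')
print(latency_dict)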
import os
import paddle
from paddleslim.analysis import TableLatencyPredictor
from .prune_model import get_sparse_model, get_prune_model
from .fake_ptq import post_quant_fake
import shutil
def predict_compressed_model(model_file, param_file, hardware='SD710'):
    """
    Evaluate the latency of a model under various compression strategies.
    Args:
        model_file(str), param_file(str): The inference model to be compressed.
        hardware(str): Target device.
    Returns:
        latency_dict(dict): The predicted latency of the model under various compression strategies.
    """
    latency_dict = {}
    model_filename = model_file.split('/')[-1]
    param_filename = param_file.split('/')[-1]

    # Baseline: predicted latency of the original FP32 model.
    predictor = TableLatencyPredictor(hardware)
    latency = predictor.predict(
        model_file=model_file, param_file=param_file, data_type='fp32')
    latency_dict.update({'origin_fp32': latency})

    # Apply fake post-training quantization to the original model, then predict its INT8 latency.
    paddle.enable_static()
    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    post_quant_fake(
        exe,
        model_dir=os.path.dirname(model_file),
        model_filename=model_filename,
        params_filename=param_filename,
        save_model_path='quant_model',
        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
        is_full_quantize=False,
        activation_bits=8,
        weight_bits=8)
    quant_model_file = os.path.join('quant_model', model_filename)
    quant_param_file = os.path.join('quant_model', param_filename)
    latency = predictor.predict(
        model_file=quant_model_file,
        param_file=quant_param_file,
        data_type='int8')
    latency_dict.update({'origin_int8': latency})
    # Structured pruning at several ratios; each pruned model is evaluated as FP32 and INT8.
    for prune_ratio in [0.3, 0.4, 0.5, 0.6]:
        get_prune_model(
            model_file=model_file,
            param_file=param_file,
            ratio=prune_ratio,
            save_path='prune_model')
        prune_model_file = os.path.join('prune_model', model_filename)
        prune_param_file = os.path.join('prune_model', param_filename)

        latency = predictor.predict(
            model_file=prune_model_file,
            param_file=prune_param_file,
            data_type='fp32')
        latency_dict.update({f'prune_{prune_ratio}_fp32': latency})

        post_quant_fake(
            exe,
            model_dir='prune_model',
            model_filename=model_filename,
            params_filename=param_filename,
            save_model_path='quant_model',
            quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
            is_full_quantize=False,
            activation_bits=8,
            weight_bits=8)
        quant_model_file = os.path.join('quant_model', model_filename)
        quant_param_file = os.path.join('quant_model', param_filename)
        latency = predictor.predict(
            model_file=quant_model_file,
            param_file=quant_param_file,
            data_type='int8')
        latency_dict.update({f'prune_{prune_ratio}_int8': latency})
    # Unstructured sparsity at several ratios; each sparse model is evaluated as FP32 and INT8.
    for sparse_ratio in [0.70, 0.75, 0.80, 0.85, 0.90, 0.95]:
        get_sparse_model(
            model_file=model_file,
            param_file=param_file,
            ratio=sparse_ratio,
            save_path='sparse_model')
        sparse_model_file = os.path.join('sparse_model', model_filename)
        sparse_param_file = os.path.join('sparse_model', param_filename)

        latency = predictor.predict(
            model_file=sparse_model_file,
            param_file=sparse_param_file,
            data_type='fp32')
        latency_dict.update({f'sparse_{sparse_ratio}_fp32': latency})

        post_quant_fake(
            exe,
            model_dir='sparse_model',
            model_filename=model_filename,
            params_filename=param_filename,
            save_model_path='quant_model',
            quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
            is_full_quantize=False,
            activation_bits=8,
            weight_bits=8)
        quant_model_file = os.path.join('quant_model', model_filename)
        quant_param_file = os.path.join('quant_model', param_filename)
        latency = predictor.predict(
            model_file=quant_model_file,
            param_file=quant_param_file,
            data_type='int8')
        latency_dict.update({f'sparse_{sparse_ratio}_int8': latency})

    # Delete temporary model files.
    shutil.rmtree('./quant_model')
    shutil.rmtree('./prune_model')
    shutil.rmtree('./sparse_model')
    return latency_dict
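A hedged sketch of how a caller might consume the returned dictionary; the key names follow the f-strings in the function above, and target_latency is a hypothetical budget supplied by the caller:

def strategies_within_budget(latency_dict, target_latency):
    # Keys look like 'origin_fp32', 'prune_0.3_int8', 'sparse_0.75_fp32', etc.;
    # values are the latencies predicted by TableLatencyPredictor.
    return {
        strategy: latency
        for strategy, latency in latency_dict.items()
        if latency <= target_latency
    }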
@@ -27,7 +27,6 @@ def get_sparse_model(model_file, param_file, ratio, save_path):
    folder = os.path.dirname(model_file)
    model_name = model_file.split('/')[-1]
    model_name = model_file.split('/')[-1]
    if param_file is None:
        param_name = None
    else:
......