未验证 提交 1a9376a8 编写于 作者: Z zhouzj 提交者: GitHub

[LatencyPredictor] Add new op and speed up prediction (#1014)

* add new op type and support fp16 model

* preload predictors' model and speed up prediction

* preload predictors' model

* preload predictors' model

* Modified the save path of TMP files
Co-authored-by: minghaoBD <79566150+minghaoBD@users.noreply.github.com>
上级 23cc74de
...@@ -18,7 +18,7 @@ import pickle ...@@ -18,7 +18,7 @@ import pickle
import paddle import paddle
import paddleslim import paddleslim
import subprocess import subprocess
import sklearn import time
__all__ = [ __all__ = [
"save_cls_model", "save_det_model", "save_seg_model", "nearest_interpolate", "save_cls_model", "save_det_model", "save_seg_model", "nearest_interpolate",
"opt_model", "load_predictor" "opt_model", "load_predictor"
...@@ -29,10 +29,11 @@ def opt_model(opt="paddle_lite_opt", ...@@ -29,10 +29,11 @@ def opt_model(opt="paddle_lite_opt",
model_file='', model_file='',
param_file='', param_file='',
optimize_out_type='protobuf', optimize_out_type='protobuf',
valid_targets='arm'): valid_targets='arm',
enable_fp16=False):
assert os.path.exists(model_file) and os.path.exists( assert os.path.exists(model_file) and os.path.exists(
param_file), f'{model_file} or {param_file} does not exist.' param_file), f'{model_file} or {param_file} does not exist.'
save_dir = f'./opt_models_tmp/{os.getpid()}' save_dir = f'./opt_models_tmp/{os.getpid()}_{time.time()}'
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
...@@ -41,8 +42,8 @@ def opt_model(opt="paddle_lite_opt", ...@@ -41,8 +42,8 @@ def opt_model(opt="paddle_lite_opt",
model_out = os.path.join(save_dir, 'pbmodel') model_out = os.path.join(save_dir, 'pbmodel')
else: else:
model_out = os.path.join(save_dir, 'model') model_out = os.path.join(save_dir, 'model')
enable_fp16 = str(enable_fp16).lower()
cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets}' cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets} --enable_fp16={enable_fp16}'
print(f'commands:{cmd}') print(f'commands:{cmd}')
m = subprocess.Popen( m = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
......
...@@ -52,8 +52,13 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -52,8 +52,13 @@ def get_features_from_paramkey(param_key, op_type, data_type):
features = None features = None
if 'conv2d' in op_type: if 'conv2d' in op_type:
flag_quant = 'quant=None' if data_type == 'fp32' else 'quant=True' if data_type == 'fp16':
if flag_quant not in param_key: quant_bits = 'bit_length=16'
elif data_type == 'int8':
quant_bits = 'bit_length=8'
else:
quant_bits = 'bit_length=None'
if quant_bits not in param_key:
return None return None
weight = re.search(r'weight=(\(\d*, \d*, \d*, \d*\))', weight = re.search(r'weight=(\(\d*, \d*, \d*, \d*\))',
...@@ -178,7 +183,7 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -178,7 +183,7 @@ def get_features_from_paramkey(param_key, op_type, data_type):
'leaky_relu' in op_type or 'tanh' in op_type or 'swish' in op_type or 'leaky_relu' in op_type or 'tanh' in op_type or 'swish' in op_type or
'softmax' in op_type or 'hard_sigmoid' in op_type or 'softmax' in op_type or 'hard_sigmoid' in op_type or
'sigmoid' in op_type or 'gelu' in op_type or 'clip' in op_type or 'sigmoid' in op_type or 'gelu' in op_type or 'clip' in op_type or
'shape' in op_type or 'interp_v2' in op_type): 'shape' in op_type or 'interp_v2' in op_type or 'sqrt' in op_type):
inputs = re.search(r'in=(\((-?\d+,* *)+\))', inputs = re.search(r'in=(\((-?\d+,* *)+\))',
param_key).group().split('=')[-1].strip( param_key).group().split('=')[-1].strip(
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import os import os
import pickle import pickle
import time import shutil
import subprocess import subprocess
from .parse_ops import get_key_from_op from .parse_ops import get_key_from_op
from .extract_features import get_data_from_tables, get_features_from_paramkey from .extract_features import get_data_from_tables, get_features_from_paramkey
...@@ -71,15 +71,16 @@ class TableLatencyPredictor(LatencyPredictor): ...@@ -71,15 +71,16 @@ class TableLatencyPredictor(LatencyPredictor):
self.hardware = None self.hardware = None
self.threads = None self.threads = None
self.predictor_state = False self.predictor_state = False
self.predictor = {}
self._initial_table() self._initial_table()
def _initial_table(self): def _initial_table(self):
if self.table_file in ['SD625', 'SD710', 'SD845', 'SD865']: if self.table_file in ['SD625', 'SD710', 'SD845', 'SD865']:
self.hardware = self.table_file self.hardware = self.table_file
if self.hardware in ['SD625', 'SD710']:
self.predictor_state = True
self.threads = 4 self.threads = 4
self.table_file = f'{self.hardware}_threads_4_power_mode_0.pkl' self.table_file = f'{self.hardware}_threads_4_power_mode_0.pkl'
if self.hardware in ['SD625', 'SD710']:
self.predictor_state = True
if not os.path.exists(self.table_file): if not os.path.exists(self.table_file):
subprocess.call( subprocess.call(
f'wget https://paddlemodels.bj.bcebos.com/PaddleSlim/analysis/{self.table_file}', f'wget https://paddlemodels.bj.bcebos.com/PaddleSlim/analysis/{self.table_file}',
...@@ -115,6 +116,19 @@ class TableLatencyPredictor(LatencyPredictor): ...@@ -115,6 +116,19 @@ class TableLatencyPredictor(LatencyPredictor):
break break
return in_shape return in_shape
# Load the latency predictor model for each supported operator type once and
# cache it in self.predictor, so later lookups can read the cached model
# instead of calling load_predictor again.
#
# Args:
#     data_type: passed through to load_predictor; for op types containing
#         'conv2d' it is also appended to the cache key (e.g. 'conv2d_fp32').
def _preload_predictor(self, data_type='fp32'):
op_types = [
'depthwise_conv2d', 'conv2d', 'pool2d', 'matmul', 'elementwise_add',
'elementwise_mul', 'concat', 'calib', 'swish'
]
# op_dir: table file name up to its first '.', plus '_batchsize_1';
# handed to load_predictor as the predictor location.
op_dir = self.table_file.split('.')[0] + '_batchsize_1'
for op_type in op_types:
model = load_predictor(op_type, op_dir, data_type)
# Cache key is the op type; conv2d variants get a data-type suffix, all
# other op types share a single entry regardless of data_type.
key = op_type
if 'conv2d' in op_type:
key = f'{op_type}_{data_type}'
self.predictor[key] = model
def predict(self, def predict(self,
model_file, model_file,
param_file, param_file,
...@@ -125,22 +139,27 @@ class TableLatencyPredictor(LatencyPredictor): ...@@ -125,22 +139,27 @@ class TableLatencyPredictor(LatencyPredictor):
Args: Args:
model_file(str), param_file(str): The inference model(*.pdmodel, *.pdiparams). model_file(str), param_file(str): The inference model(*.pdmodel, *.pdiparams).
data_type(str): Data type, fp32 or int8. Default : fp32 data_type(str): Data type, fp32, fp16 or int8.
threads(int): threads num threads(int): threads num
input_shape(list): Generally, the input shape is confirmed when saving the inference model and the parameter is only effective for input shape that has variable length. input_shape(list): Generally, the input shape is confirmed when saving the inference model and the parameter is only effective for input shape that has variable length.
Returns: Returns:
latency(float): The latency of the model. latency(float): The latency of the model.
""" """
assert data_type in ['fp32', 'int8' assert data_type in ['fp32', 'int8', 'fp16'
], f'data_type must be one of [fp32, int8]' ], f'data_type must be one of [fp32, int8, fp16]'
if self.hardware and self.threads != threads: if self.hardware and self.threads != threads:
self._change_table(threads) self._change_table(threads)
if self.predictor_state and f'conv2d_{data_type}' not in self.predictor:
self._preload_predictor(data_type)
enable_fp16 = True if data_type == 'fp16' else False
pbmodel_file = opt_model( pbmodel_file = opt_model(
model_file=model_file, model_file=model_file,
param_file=param_file, param_file=param_file,
optimize_out_type='protobuf', ) optimize_out_type='protobuf',
enable_fp16=enable_fp16)
paddle.enable_static() paddle.enable_static()
with open(pbmodel_file, "rb") as f: with open(pbmodel_file, "rb") as f:
...@@ -176,7 +195,7 @@ class TableLatencyPredictor(LatencyPredictor): ...@@ -176,7 +195,7 @@ class TableLatencyPredictor(LatencyPredictor):
warnings.warn("OperatorType\tCalledTimes") warnings.warn("OperatorType\tCalledTimes")
for key in new_op: for key in new_op:
warnings.warn(f"{key.ljust(15)}\t{new_op[key]}") warnings.warn(f"{key.ljust(15)}\t{new_op[key]}")
shutil.rmtree(os.path.dirname(pbmodel_file))
return latency return latency
def op_predictor(self, op_type, param_key, data_type): def op_predictor(self, op_type, param_key, data_type):
...@@ -185,18 +204,20 @@ class TableLatencyPredictor(LatencyPredictor): ...@@ -185,18 +204,20 @@ class TableLatencyPredictor(LatencyPredictor):
Args: Args:
op_type: The operator's type op_type: The operator's type
param_key: The operator's parameter information. param_key: The operator's parameter information.
data_type: Data type, fp32 or int8. Default : int8 data_type: Data type, fp32 or int8.
Returns: Returns:
latency(float): The latency of the operator. latency(float): The latency of the operator.
""" """
latency = 0.0 latency = 0.0
op_dir = self.table_file.split('.')[0] + '_batchsize_1'
if op_type in [ if op_type in [
'depthwise_conv2d', 'conv2d', 'pool2d', 'matmul', 'depthwise_conv2d', 'conv2d', 'pool2d', 'matmul',
'elementwise_add', 'elementwise_mul', 'concat', 'calib', 'swish' 'elementwise_add', 'elementwise_mul', 'concat', 'calib', 'swish'
]: ]:
predictor = load_predictor(op_type, op_dir, data_type) key = op_type
if 'conv2d' in op_type:
key = f'{op_type}_{data_type}'
predictor = self.predictor[key]
features = get_features_from_paramkey(param_key, op_type, data_type) features = get_features_from_paramkey(param_key, op_type, data_type)
latency = predictor.predict([features]) latency = predictor.predict([features])
else: else:
......
...@@ -24,25 +24,30 @@ def get_key_from_op(op): ...@@ -24,25 +24,30 @@ def get_key_from_op(op):
if 'conv2d' in op_type: if 'conv2d' in op_type:
out_shape = op.all_outputs()[0].shape() out_shape = op.all_outputs()[0].shape()
in_shape = op.all_inputs()[-1].shape() in_shape = op.all_inputs()[-1].shape()
in_name = op.all_inputs()[1].name()
weight_shape = op.all_inputs()[-2].shape() weight_shape = op.all_inputs()[-2].shape()
kernel = weight_shape[2] weight_shape = (out_shape[1], weight_shape[1], weight_shape[2], weight_shape[3])
stride = op.attr('strides')[1] stride = op.attr('strides')[1]
padding = op.attr('paddings')[1] padding = op.attr('paddings')[1]
groups = op.attr('groups') groups = op.attr('groups')
dilation = op.attr('dilations')[1] dilation = op.attr('dilations')[1]
int8 = op.attr('enable_int8') quant = op.attr('enable_int8')
bit_length = op.attr('bit_length') bit_length = op.attr('bit_length')
if op.attr(in_name+'_fp16') == 'fp16':
quant = True
bit_length = 16
param_key = f'{op_type} in={in_shape} weight={weight_shape} out={out_shape} pad={padding} stride={stride} group={groups} dilation={dilation} quant={int8} bit_length={bit_length}' param_key = f'{op_type} in={in_shape} weight={weight_shape} out={out_shape} pad={padding} stride={stride} group={groups} dilation={dilation} quant={quant} bit_length={bit_length}'
elif op_type == 'matmul' or op_type == 'matmul_v2': elif op_type == 'matmul' or op_type == 'matmul_v2':
X = op.all_inputs()[0].shape() X = op.all_inputs()[0].shape()
Y = op.all_inputs()[1].shape() Y = op.all_inputs()[1].shape()
out_shape = op.all_outputs()[0].shape() out_shape = op.all_outputs()[0].shape()
int8 = op.attr('enable_int8') quant = op.attr('enable_int8')
bit_length = op.attr('bit_length') bit_length = op.attr('bit_length')
param_key = f'{op_type} X={X} Y={Y} out={out_shape} quant={int8} bit_length={bit_length}' param_key = f'{op_type} X={X} Y={Y} out={out_shape} quant={quant} bit_length={bit_length}'
elif 'batch_norm' in op_type or 'layer_norm' in op_type: elif 'batch_norm' in op_type or 'layer_norm' in op_type:
out_shape = op.all_outputs()[-1].shape() out_shape = op.all_outputs()[-1].shape()
...@@ -67,14 +72,12 @@ def get_key_from_op(op): ...@@ -67,14 +72,12 @@ def get_key_from_op(op):
elif op_type in [ elif op_type in [
'hard_swish', 'relu', 'leaky_relu', 'tanh', 'swish', 'softmax', 'hard_swish', 'relu', 'leaky_relu', 'tanh', 'swish', 'softmax',
'hard_sigmoid', 'sigmoid', 'gelu', 'clip', 'shape' 'hard_sigmoid', 'sigmoid', 'gelu', 'clip', 'shape', 'sqrt'
] or 'transpose' in op_type or 'interp_v2' in op_type: ] or 'transpose' in op_type or 'interp_v2' in op_type:
in_shape = op.all_inputs()[-1].shape() in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} in={in_shape}' param_key = f'{op_type} in={in_shape} out={out_shape}'
in_shape = op.all_inputs()[-1].shape()
param_key = f'{op_type} in={in_shape}'
elif op_type in ['fill_constant', 'range', 'cast'] or 'expand' in op_type: elif op_type in ['fill_constant', 'range', 'cast'] or 'expand' in op_type:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册