Unverified commit dd85e83a, authored by zhouzj, committed by GitHub

Update predictor api (#1024)

* fix a pruner's bug

* support predict pruned model

* support predict pruned model

* support sparse model

* add interface for quick pruning

* add hardware list for latencyPredictor
Parent dc44c944
@@ -30,7 +30,8 @@ def opt_model(opt="paddle_lite_opt",
param_file='',
optimize_out_type='protobuf',
valid_targets='arm',
enable_fp16=False):
enable_fp16=False,
sparse_ratio=0):
assert os.path.exists(model_file) and os.path.exists(
param_file), f'{model_file} or {param_file} does not exist.'
save_dir = f'./opt_models_tmp/{os.getpid()}_{time.time()}'
@@ -42,8 +43,12 @@ def opt_model(opt="paddle_lite_opt",
model_out = os.path.join(save_dir, 'pbmodel')
else:
model_out = os.path.join(save_dir, 'model')
enable_fp16 = str(enable_fp16).lower()
cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets} --enable_fp16={enable_fp16}'
sparse_model = True if sparse_ratio > 0 else False
sparse_threshold = max(sparse_ratio - 0.1, 0.1)
cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets} --enable_fp16={enable_fp16} --sparse_model={sparse_model} --sparse_threshold={sparse_threshold}'
print(f'commands:{cmd}')
m = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
......
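For reference, a minimal sketch of driving this path, assuming opt_model is importable as below (import path and file names are assumptions); a sparse_ratio greater than 0 turns on --sparse_model and derives --sparse_threshold as max(sparse_ratio - 0.1, 0.1):

from paddleslim.analysis import opt_model  # assumed import path

# Convert a sparse inference model with paddle_lite_opt for ARM; with
# sparse_ratio=0.75 the command gains --sparse_model=True --sparse_threshold=0.65.
opt_model(
    opt='paddle_lite_opt',
    model_file='inference/model.pdmodel',    # placeholder paths
    param_file='inference/model.pdiparams',
    optimize_out_type='naive_buffer',
    valid_targets='arm',
    sparse_ratio=0.75)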
@@ -72,15 +72,15 @@ class TableLatencyPredictor(LatencyPredictor):
self.threads = None
self.predictor_state = False
self.predictor = {}
self.hardware_list = ['SD625', 'SD710']
self._initial_table()
def _initial_table(self):
if self.table_file in ['SD625', 'SD710', 'SD845', 'SD865']:
if self.table_file in self.hardware_list:
self.hardware = self.table_file
self.threads = 4
self.table_file = f'{self.hardware}_threads_4_power_mode_0.pkl'
if self.hardware in ['SD625', 'SD710']:
self.predictor_state = True
self.predictor_state = True
if not os.path.exists(self.table_file):
subprocess.call(
f'wget https://paddlemodels.bj.bcebos.com/PaddleSlim/analysis/{self.table_file}',
@@ -88,7 +88,7 @@ class TableLatencyPredictor(LatencyPredictor):
assert os.path.exists(
self.table_file
), f'{self.table_file} does not exist. If you want to use our table files, please set \'table_file\' in [SD625, SD710, SD845, SD865]'
), f'{self.table_file} does not exist. If you want to use our table files, please set \'table_file\' in {self.hardware_list}'
with open(self.table_file, 'rb') as f:
self.table_dict = pickle.load(f)
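A minimal sketch of picking one of the tables in hardware_list, assuming TableLatencyPredictor is exposed by paddleslim.analysis:

from paddleslim.analysis import TableLatencyPredictor  # assumed import path

# 'SD710' resolves to SD710_threads_4_power_mode_0.pkl (downloaded on first
# use) and enables the operator-level predictor for this hardware.
predictor = TableLatencyPredictor(table_file='SD710')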
@@ -140,7 +140,9 @@ class TableLatencyPredictor(LatencyPredictor):
Args:
model_file(str), param_file(str): The inference model(*.pdmodel, *.pdiparams).
data_type(str): Data type, fp32, fp16 or int8.
threads(int): threads num
threads(int): Threads num.
sparse_ratio(float): The ratio of unstructured pruning.
prune_ratio(float): The ratio of structured pruning.
input_shape(list): Generally, the input shape is confirmed when saving the inference model and the parameter is only effective for input shape that has variable length.
Returns:
latency(float): The latency of the model.
......
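Building on the predictor above, a sketch of the updated predict call; the keyword names follow the documented arguments and the paths are placeholders:

latency = predictor.predict(
    model_file='inference/model.pdmodel',
    param_file='inference/model.pdiparams',
    data_type='fp32',
    threads=4,
    sparse_ratio=0.75,   # estimate latency as if unstructured-pruned to 75% sparsity
    prune_ratio=0.25)    # and/or structured-pruned by 25%
print('predicted latency:', latency)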
@@ -21,22 +21,38 @@ def get_key_from_op(op):
param_key = ''
op_type = op.type()
if 'conv2d' in op_type:
if op_type == 'sparse_conv2d':
out_shape = op.all_outputs()[0].shape()
in_shape = op.inputs('Input')[0].shape()
weight_shape = (out_shape[1], in_shape[1], 1, 1)
NonZeroWeights = op.inputs('NonZeroWeights')[0].shape()[0]
stride = op.attr('strides')[1]
padding = op.attr('paddings')[1]
groups = op.attr('groups')
dilation = op.attr('dilations')[1]
quant = op.attr('enable_int8')
bit_length = op.attr('bit_length')
param_key = f'{op_type} in={in_shape} weight={weight_shape} out={out_shape} pad={padding} stride={stride} group={groups} dilation={dilation} quant={quant} bit_length={bit_length} NonZeroWeights={NonZeroWeights}'
elif 'conv2d' in op_type:
out_shape = op.all_outputs()[0].shape()
in_shape = op.all_inputs()[-1].shape()
in_name = op.all_inputs()[1].name()
weight_shape = op.all_inputs()[-2].shape()
weight_shape = (out_shape[1], weight_shape[1], weight_shape[2], weight_shape[3])
weight_shape = (out_shape[1], weight_shape[1], weight_shape[2],
weight_shape[3])
stride = op.attr('strides')[1]
padding = op.attr('paddings')[1]
groups = op.attr('groups')
dilation = op.attr('dilations')[1]
quant = op.attr('enable_int8')
bit_length = op.attr('bit_length')
if op.attr(in_name+'_fp16') == 'fp16':
if op.attr(in_name + '_fp16') == 'fp16':
quant = True
bit_length = 16
bit_length = 16
param_key = f'{op_type} in={in_shape} weight={weight_shape} out={out_shape} pad={padding} stride={stride} group={groups} dilation={dilation} quant={quant} bit_length={bit_length}'
......
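To make the key format above concrete, a hand-constructed example for a 1x1 sparse_conv2d (all values are made up):

op_type = 'sparse_conv2d'
in_shape = [1, 64, 56, 56]    # NCHW input shape
out_shape = [1, 128, 56, 56]
weight_shape = (out_shape[1], in_shape[1], 1, 1)   # rebuilt as (out_c, in_c, 1, 1)
NonZeroWeights = 2048         # number of kept (non-zero) weights
padding, stride, groups, dilation = 0, 1, 1, 1
quant, bit_length = False, 8
param_key = (f'{op_type} in={in_shape} weight={weight_shape} out={out_shape} '
             f'pad={padding} stride={stride} group={groups} dilation={dilation} '
             f'quant={quant} bit_length={bit_length} NonZeroWeights={NonZeroWeights}')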
import os
import time
import paddle
import paddle.fluid as fluid
import paddle.static as static
from paddleslim.prune import Pruner
from paddleslim.core import GraphWrapper
import numpy as np
__all__ = ["get_sparse_model", "get_prune_model"]
def get_sparse_model(model_file, param_file, ratio, save_path):
"""
Using the unstructured sparse algorithm to compress the network.
This interface is only used to evaluate the latency of the compressed network, and does not consider the loss of accuracy.
Args:
model_file(str), param_file(str): The inference model to be pruned.
ratio(float): The ratio to prune the model.
save_path(str): The save path of pruned model.
"""
assert os.path.exists(model_file), f'{model_file} does not exist.'
assert os.path.exists(
param_file) or param_file is None, f'{param_file} does not exist.'
paddle.enable_static()
SKIP = ['image', 'feed', 'pool2d_0.tmp_0']
folder = os.path.dirname(model_file)
model_name = model_file.split('/')[-1]
if param_file is None:
param_name = None
else:
param_name = param_file.split('/')[-1]
main_prog = static.Program()
startup_prog = static.Program()
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
[inference_program, feed_target_names, fetch_targets] = (
fluid.io.load_inference_model(
folder, exe, model_filename=model_name, params_filename=param_name))
thresholds = {}
graph = GraphWrapper(inference_program)
for op in graph.ops():
for inp in op.all_inputs():
name = inp.name()
if inp.name() in SKIP: continue
if 'tmp' in inp.name(): continue
# 1x1_conv
cond_conv = len(inp._var.shape) == 4 and inp._var.shape[
2] == 1 and inp._var.shape[3] == 1
cond_fc = False
if cond_fc or cond_conv:
array = np.array(paddle.static.global_scope().find_var(name)
.get_tensor())
flatten = np.abs(array.flatten())
index = min(len(flatten) - 1, int(ratio * len(flatten)))
ind = np.unravel_index(
np.argsort(
flatten, axis=None), flatten.shape)
thresholds[name] = ind[0][:index]
for op in graph.ops():
for inp in op.all_inputs():
name = inp.name()
if name in SKIP: continue
if 'tmp' in inp.name(): continue
cond_conv = (len(inp._var.shape) == 4 and inp._var.shape[2] == 1 and
inp._var.shape[3] == 1)
cond_fc = False
# only support 1x1_conv now
if not (cond_conv or cond_fc): continue
array = np.array(paddle.static.global_scope().find_var(name)
.get_tensor())
if thresholds.get(name) is not None:
np.put(array, thresholds.get(name), 0)
assert (abs(1 - np.count_nonzero(array) / array.size - ratio) < 1e-2
), 'The model sparsity is abnormal.'
paddle.static.global_scope().find_var(name).get_tensor().set(
array, paddle.CPUPlace())
fluid.io.save_inference_model(
save_path,
feeded_var_names=feed_target_names,
target_vars=fetch_targets,
executor=exe,
main_program=inference_program,
model_filename=model_name,
params_filename=param_name)
print("The pruned model is saved in: ", save_path)
def get_prune_model(model_file, param_file, ratio, save_path):
"""
Using the structured pruning algorithm to compress the network.
This interface is only used to evaluate the latency of the compressed network, and does not consider the loss of accuracy.
Args:
model_file(str), param_file(str): The inference model to be pruned.
ratio(float): The ratio to prune the model.
save_path(str): The save path of pruned model.
"""
assert os.path.exists(model_file), f'{model_file} does not exist.'
assert os.path.exists(
param_file) or param_file is None, f'{param_file} does not exist.'
paddle.enable_static()
SKIP = ['image', 'feed', 'pool2d_0.tmp_0']
folder = os.path.dirname(model_file)
model_name = model_file.split('/')[-1]
if param_file is None:
param_name = None
else:
param_name = param_file.split('/')[-1]
main_prog = static.Program()
startup_prog = static.Program()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
scope = static.global_scope()
exe.run(startup_prog)
[inference_program, feed_target_names, fetch_targets] = (
fluid.io.load_inference_model(
folder, exe, model_filename=model_name, params_filename=param_name))
prune_params = []
graph = GraphWrapper(inference_program)
for op in graph.ops():
for inp in op.all_inputs():
name = inp.name()
if inp.name() in SKIP: continue
if 'tmp' in inp.name(): continue
cond_conv = len(inp._var.shape) == 4 and 'conv' in name
# only prune conv
if cond_conv:
prune_params.append(name)
# drop last conv
prune_params.pop()
ratios = [ratio] * len(prune_params)
pruner = Pruner()
main_program, _, _ = pruner.prune(
inference_program,
scope,
params=prune_params,
ratios=ratios,
place=place,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None)
fluid.io.save_inference_model(
save_path,
feeded_var_names=feed_target_names,
target_vars=fetch_targets,
executor=exe,
main_program=main_program,
model_filename=model_name,
params_filename=param_name)
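Its structured counterpart works the same way; a sketch with an assumed import path and placeholder paths:

from paddleslim.analysis import get_prune_model  # assumed import path

# Prune 25% of the filters of every conv weight found above (the last conv is
# dropped from the prune list) and save the smaller model for latency tests.
get_prune_model(
    model_file='inference/model.pdmodel',
    param_file='inference/model.pdiparams',
    ratio=0.25,
    save_path='./pruned_model')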
@@ -387,5 +387,7 @@ class GraphWrapper(object):
It is used after loading pruned parameters from file.
"""
for op in self.ops():
if op.type() != 'conditional_block' and op.type() != 'feed':
if op.type() in ['feed', 'fetch']:
continue
if op.type() != 'conditional_block':
op._op.desc.infer_shape(op._op.block.desc)
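A short sketch of calling this after pruned parameters have been written into the scope, assuming the enclosing method above is GraphWrapper.infer_shape():

import paddle
from paddleslim.core import GraphWrapper

paddle.enable_static()
# In practice this is the loaded inference program whose pruned parameters
# were just set into the global scope; an empty Program also walks cleanly.
prog = paddle.static.Program()
graph = GraphWrapper(prog)
graph.infer_shape()   # re-infers op shapes, skipping feed/fetch ops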