Unverified commit 46751fe7, authored by zhouzj, committed by GitHub

Method of adding hardware (#1041)

* fix opt's cmd for sparse model

* add hardware

* Remove redundant functions and adjust test files
Parent cb642009
@@ -16,18 +16,10 @@ from .model_size import model_size
 from .latency import LatencyEvaluator, TableLatencyEvaluator
 from .latency_predictor import LatencyPredictor, TableLatencyPredictor
 from .parse_ops import get_key_from_op
-from ._utils import save_cls_model, save_det_model, save_seg_model
+from ._utils import save_cls_model, save_det_model
 
 __all__ = [
-    'flops',
-    'dygraph_flops',
-    'model_size',
-    'LatencyEvaluator',
-    'TableLatencyEvaluator',
-    "LatencyPredictor",
-    "TableLatencyPredictor",
-    "get_key_from_op",
-    "save_cls_model",
-    "save_det_model",
-    "save_seg_model",
+    'flops', 'dygraph_flops', 'model_size', 'LatencyEvaluator',
+    'TableLatencyEvaluator', "LatencyPredictor", "TableLatencyPredictor",
+    "get_key_from_op", "save_cls_model", "save_det_model"
 ]
@@ -20,8 +20,8 @@ import paddleslim
 import subprocess
 import time
 
 __all__ = [
-    "save_cls_model", "save_det_model", "save_seg_model", "nearest_interpolate",
-    "opt_model", "load_predictor"
+    "save_cls_model", "save_det_model", "nearest_interpolate", "opt_model",
+    "load_predictor"
 ]
@@ -30,8 +30,7 @@ def opt_model(opt="paddle_lite_opt",
               param_file='',
               optimize_out_type='protobuf',
               valid_targets='arm',
-              enable_fp16=False,
-              sparse_ratio=0):
+              enable_fp16=False):
     assert os.path.exists(model_file) and os.path.exists(
         param_file), f'{model_file} or {param_file} does not exist.'
     save_dir = f'./opt_models_tmp/{os.getpid()}_{time.time()}'
@@ -40,15 +39,13 @@
     assert optimize_out_type in ['protobuf', 'naive_buffer']
     if optimize_out_type == 'protobuf':
-        model_out = os.path.join(save_dir, 'pbmodel')
+        model_out = save_dir
     else:
         model_out = os.path.join(save_dir, 'model')
     enable_fp16 = str(enable_fp16).lower()
-    sparse_model = True if sparse_ratio > 0 else False
-    sparse_threshold = max(sparse_ratio - 0.1, 0.1)
-    cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets} --enable_fp16={enable_fp16} --sparse_model={sparse_model} --sparse_threshold={sparse_threshold}'
+    cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets} --enable_fp16={enable_fp16} --sparse_model=true --sparse_threshold=0.4'
     print(f'commands:{cmd}')
     m = subprocess.Popen(
         cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
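For context, a minimal usage sketch of `opt_model` under its new signature (the file paths are hypothetical placeholders, and `paddle_lite_opt` is assumed to be installed and on `PATH`):

```python
# Minimal sketch: convert a saved Paddle inference model with the updated
# opt_model() helper. Note that sparse_ratio is gone; --sparse_model=true and
# --sparse_threshold=0.4 are now fixed inside the generated command.
from paddleslim.analysis._utils import opt_model

pbmodel_file = opt_model(
    opt='paddle_lite_opt',                               # assumes the binary is on PATH
    model_file='./inference_model/fp32model.pdmodel',    # hypothetical path
    param_file='./inference_model/fp32model.pdiparams',  # hypothetical path
    optimize_out_type='protobuf',
    valid_targets='arm',
    enable_fp16=False)
```

Note that with `optimize_out_type='protobuf'`, the output now lands directly in the temporary `save_dir` rather than in a `pbmodel` subdirectory, per the `model_out = save_dir` change above.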
@@ -169,48 +166,6 @@ def save_det_model(model,
     return model_file, param_file
 
-def save_seg_model(model, input_shape, save_dir, data_type):
-    if data_type == 'fp32':
-        paddle.jit.save(
-            model,
-            path=os.path.join(save_dir, 'fp32model'),
-            input_spec=[
-                paddle.static.InputSpec(
-                    shape=input_shape, dtype='float32', name='x'),
-            ])
-        model_file = os.path.join(save_dir, 'fp32model.pdmodel')
-        param_file = os.path.join(save_dir, 'fp32model.pdiparams')
-    else:
-        save_dir = os.path.join(save_dir, 'int8model')
-        quant_config = {
-            'weight_preprocess_type': None,
-            'activation_preprocess_type': None,
-            'weight_quantize_type': 'channel_wise_abs_max',
-            'activation_quantize_type': 'moving_average_abs_max',
-            'weight_bits': 8,
-            'activation_bits': 8,
-            'dtype': 'int8',
-            'window_size': 10000,
-            'moving_rate': 0.9,
-            'quantizable_layer_type': ['Conv2D', 'Linear'],
-        }
-        quantizer = paddleslim.QAT(config=quant_config)
-        quantizer.quantize(model)
-        quantizer.save_quantized_model(
-            model,
-            save_dir,
-            input_spec=[
-                paddle.static.InputSpec(
-                    shape=input_shape, dtype='float32')
-            ])
-        model_file = f'{save_dir}.pdmodel'
-        param_file = f'{save_dir}.pdiparams'
-    return model_file, param_file
 
 def nearest_interpolate(features, data):
     def distance(x, y):
         x = np.array(x)
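The hunk ends at the top of `nearest_interpolate`, which this commit leaves unchanged. For readers unfamiliar with it, a hedged sketch of the general nearest-neighbor idea behind such a helper (an illustration of the technique only, not the library's exact implementation; the function name and table layout are assumptions):

```python
import numpy as np

def nearest_interpolate_sketch(features, data):
    # Illustration only: 'data' is assumed to hold rows of
    # [feature_0, ..., feature_n, latency]. Return the latency of the row
    # whose feature vector is closest (Euclidean distance) to 'features'.
    features = np.array(features, dtype='float32')
    table = np.array(data, dtype='float32')
    dists = np.linalg.norm(table[:, :-1] - features, axis=1)
    return float(table[np.argmin(dists), -1])
```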
@@ -39,7 +39,7 @@ class LatencyPredictor(object):
     """Base class of latency predictor.
     """
 
-    def predict_latency(self, model):
+    def predict(self, model):
         """Get latency of model. It is an abstract method.
 
         Args:
@@ -64,6 +64,7 @@ class TableLatencyPredictor(LatencyPredictor):
     Args:
         table_file(str): The path of file that records the device latency of operators.
     """
+    hardware_list = ['SD625', 'SD710']
 
     def __init__(self, table_file='SD710'):
         self.table_file = table_file
@@ -72,11 +73,14 @@ class TableLatencyPredictor(LatencyPredictor):
         self.threads = None
         self.predictor_state = False
         self.predictor = {}
-        self.hardware_list = ['SD625', 'SD710']
         self._initial_table()
 
+    @classmethod
+    def add_hardware(cls, hardware):
+        cls.hardware_list.append(hardware)
+
     def _initial_table(self):
-        if self.table_file in self.hardware_list:
+        if self.table_file in TableLatencyPredictor.hardware_list:
             self.hardware = self.table_file
             self.threads = 4
             self.table_file = f'{self.hardware}_threads_4_power_mode_0.pkl'
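The class-level `hardware_list` plus the new `add_hardware` classmethod let callers register extra devices without touching the library. A short sketch of the intended flow (`'SD855'` is a hypothetical name; a matching latency table must exist for it):

```python
from paddleslim.analysis import TableLatencyPredictor

# Register a new hardware name on the shared class-level list added above.
# 'SD855' is a hypothetical example; _initial_table() will then look for a
# table file named 'SD855_threads_4_power_mode_0.pkl'.
TableLatencyPredictor.add_hardware('SD855')

predictor = TableLatencyPredictor(table_file='SD855')
```

Because `hardware_list` is now shared class state, a name registered once is visible to every later `TableLatencyPredictor` instance, which is what the lookup change to `TableLatencyPredictor.hardware_list` relies on.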
@@ -88,7 +92,7 @@
         assert os.path.exists(
             self.table_file
-        ), f'{self.table_file} does not exist. If you want to use our table files, please set \'table_file\' in {self.hardware_list}'
+        ), f'{self.table_file} does not exist. If you want to use our table files, please set \'table_file\' in {TableLatencyPredictor.hardware_list}'
         with open(self.table_file, 'rb') as f:
             self.table_dict = pickle.load(f)
@@ -123,6 +127,8 @@
         ]
         op_dir = self.table_file.split('.')[0] + '_batchsize_1'
         for op_type in op_types:
+            if data_type == 'fp32' and op_type == 'calib':
+                continue
             model = load_predictor(op_type, op_dir, data_type)
             key = op_type
             if 'conv2d' in op_type:
model_file(str), param_file(str): The inference model(*.pdmodel, *.pdiparams).
data_type(str): Data type, fp32, fp16 or int8.
threads(int): Threads num.
sparse_ratio(float): The ratio of unstructured pruning.
prune_ratio(float): The ration of structured pruning.
input_shape(list): Generally, the input shape is confirmed when saving the inference model and the parameter is only effective for input shape that has variable length.
Returns:
latency(float): The latency of the model.
......
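With the sparse/prune ratios dropped from the docstring, a call to `predict` looks roughly like this sketch (the paths are placeholders for a previously exported inference model, and the keyword names follow the docstring above):

```python
# Sketch: predict the latency of an exported inference model using the
# bundled SD710 latency table. Model/param paths are hypothetical.
predictor = TableLatencyPredictor(table_file='SD710')
latency = predictor.predict(
    model_file='./inference_model/fp32model.pdmodel',
    param_file='./inference_model/fp32model.pdiparams',
    data_type='fp32')
print(f'predicted latency: {latency}')
```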
@@ -19,7 +19,7 @@ import paddleslim
 from paddleslim.analysis import LatencyPredictor, TableLatencyPredictor
 from paddle.vision.models import mobilenet_v1, mobilenet_v2
 from paddle.nn import Conv2D, BatchNorm2D, ReLU, LayerNorm
-from paddleslim.analysis._utils import opt_model, save_cls_model, save_seg_model, save_det_model
+from paddleslim.analysis._utils import opt_model, save_cls_model, save_det_model
 
 def channel_shuffle(x, groups):
@@ -276,7 +276,7 @@ class TestCase5(unittest.TestCase):
         paddle.disable_static()
         model = mobilenet_v1()
         predictor = TableLatencyPredictor(table_file='SD710')
-        model_file, param_file = save_seg_model(
+        model_file, param_file = save_cls_model(
             model,
             input_shape=[1, 3, 224, 224],
             save_dir="./inference_model",
@@ -367,30 +367,6 @@ class TestCase9(unittest.TestCase):
-class TestCase10(unittest.TestCase):
-    def test_case10(self):
-        paddle.disable_static()
-        model = ModelCase1()
-        predictor = LatencyPredictor()
-        model_file, param_file = save_seg_model(
-            model,
-            input_shape=[1, 116, 28, 28],
-            save_dir="./inference_model",
-            data_type='int8')
-        pbmodel_file = opt_model(
-            model_file=model_file,
-            param_file=param_file,
-            optimize_out_type='protobuf')
-        paddle.enable_static()
-        with open(pbmodel_file, "rb") as f:
-            fluid_program = paddle.fluid.framework.Program.parse_from_string(
-                f.read())
-            graph = paddleslim.core.GraphWrapper(fluid_program)
-            graph_keys = predictor._get_key_info_from_graph(graph=graph)
-            assert len(graph_keys) > 0
-
 class TestCase11(unittest.TestCase):
     def test_case11(self):
         paddle.disable_static()
         model = mobilenet_v2()
         model2 = ModelCase6()