diff --git a/paddleslim/analysis/__init__.py b/paddleslim/analysis/__init__.py
index df10ad516dbdf11a65cd4b6fea1d94e9b047be67..ba30217fd72cb5fd3261ef732d65cbef5cae6ca2 100644
--- a/paddleslim/analysis/__init__.py
+++ b/paddleslim/analysis/__init__.py
@@ -16,18 +16,10 @@ from .model_size import model_size
 from .latency import LatencyEvaluator, TableLatencyEvaluator
 from .latency_predictor import LatencyPredictor, TableLatencyPredictor
 from .parse_ops import get_key_from_op
-from ._utils import save_cls_model, save_det_model, save_seg_model
+from ._utils import save_cls_model, save_det_model
 
 __all__ = [
-    'flops',
-    'dygraph_flops',
-    'model_size',
-    'LatencyEvaluator',
-    'TableLatencyEvaluator',
-    "LatencyPredictor",
-    "TableLatencyPredictor",
-    "get_key_from_op",
-    "save_cls_model",
-    "save_det_model",
-    "save_seg_model",
+    'flops', 'dygraph_flops', 'model_size', 'LatencyEvaluator',
+    'TableLatencyEvaluator', "LatencyPredictor", "TableLatencyPredictor",
+    "get_key_from_op", "save_cls_model", "save_det_model"
 ]
diff --git a/paddleslim/analysis/_utils.py b/paddleslim/analysis/_utils.py
index d0f0d95b304392f2d5e5042aea3a9edc6695759b..c9fba4a7765625e439496a45c1f1b691a6e9d0a1 100644
--- a/paddleslim/analysis/_utils.py
+++ b/paddleslim/analysis/_utils.py
@@ -20,8 +20,8 @@ import paddleslim
 import subprocess
 import time
 __all__ = [
-    "save_cls_model", "save_det_model", "save_seg_model", "nearest_interpolate",
-    "opt_model", "load_predictor"
+    "save_cls_model", "save_det_model", "nearest_interpolate", "opt_model",
+    "load_predictor"
 ]
 
 
@@ -30,8 +30,7 @@ def opt_model(opt="paddle_lite_opt",
               param_file='',
               optimize_out_type='protobuf',
               valid_targets='arm',
-              enable_fp16=False,
-              sparse_ratio=0):
+              enable_fp16=False):
     assert os.path.exists(model_file) and os.path.exists(
         param_file), f'{model_file} or {param_file} does not exist.'
     save_dir = f'./opt_models_tmp/{os.getpid()}_{time.time()}'
@@ -40,15 +39,13 @@
 
     assert optimize_out_type in ['protobuf', 'naive_buffer']
     if optimize_out_type == 'protobuf':
-        model_out = os.path.join(save_dir, 'pbmodel')
+        model_out = save_dir
     else:
         model_out = os.path.join(save_dir, 'model')
 
     enable_fp16 = str(enable_fp16).lower()
-    sparse_model = True if sparse_ratio > 0 else False
-    sparse_threshold = max(sparse_ratio - 0.1, 0.1)
 
-    cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets} --enable_fp16={enable_fp16} --sparse_model={sparse_model} --sparse_threshold={sparse_threshold}'
+    cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets} --enable_fp16={enable_fp16} --sparse_model=true --sparse_threshold=0.4'
     print(f'commands:{cmd}')
     m = subprocess.Popen(
         cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
@@ -169,48 +166,6 @@ def save_det_model(model,
     return model_file, param_file
 
 
-def save_seg_model(model, input_shape, save_dir, data_type):
-    if data_type == 'fp32':
-        paddle.jit.save(
-            model,
-            path=os.path.join(save_dir, 'fp32model'),
-            input_spec=[
-                paddle.static.InputSpec(
-                    shape=input_shape, dtype='float32', name='x'),
-            ])
-        model_file = os.path.join(save_dir, 'fp32model.pdmodel')
-        param_file = os.path.join(save_dir, 'fp32model.pdiparams')
-
-    else:
-        save_dir = os.path.join(save_dir, 'int8model')
-        quant_config = {
-            'weight_preprocess_type': None,
-            'activation_preprocess_type': None,
-            'weight_quantize_type': 'channel_wise_abs_max',
-            'activation_quantize_type': 'moving_average_abs_max',
-            'weight_bits': 8,
-            'activation_bits': 8,
-            'dtype': 'int8',
-            'window_size': 10000,
-            'moving_rate': 0.9,
-            'quantizable_layer_type': ['Conv2D', 'Linear'],
-        }
-        quantizer = paddleslim.QAT(config=quant_config)
-        quantizer.quantize(model)
-        quantizer.save_quantized_model(
-            model,
-            save_dir,
-            input_spec=[
-                paddle.static.InputSpec(
-                    shape=input_shape, dtype='float32')
-            ])
-
-        model_file = f'{save_dir}.pdmodel'
-        param_file = f'{save_dir}.pdiparams'
-
-    return model_file, param_file
-
-
 def nearest_interpolate(features, data):
     def distance(x, y):
         x = np.array(x)
diff --git a/paddleslim/analysis/latency_predictor.py b/paddleslim/analysis/latency_predictor.py
index a67e46fd5ef32d2e2dcc774cd87f74cc652e3f82..701570635eefbbc872aaee45504d9bbc00f127fa 100644
--- a/paddleslim/analysis/latency_predictor.py
+++ b/paddleslim/analysis/latency_predictor.py
@@ -39,7 +39,7 @@ class LatencyPredictor(object):
     """Base class of latency predictor.
     """
 
-    def predict_latency(self, model):
+    def predict(self, model):
         """Get latency of model. It is an abstract method.
 
         Args:
@@ -64,6 +64,7 @@ class TableLatencyPredictor(LatencyPredictor):
     Args:
         table_file(str): The path of file that records the device latency of operators.
     """
+    hardware_list = ['SD625', 'SD710']
 
     def __init__(self, table_file='SD710'):
         self.table_file = table_file
@@ -72,11 +73,14 @@ class TableLatencyPredictor(LatencyPredictor):
         self.threads = None
         self.predictor_state = False
         self.predictor = {}
-        self.hardware_list = ['SD625', 'SD710']
         self._initial_table()
 
+    @classmethod
+    def add_hardware(cls, hardware):
+        cls.hardware_list.append(hardware)
+
     def _initial_table(self):
-        if self.table_file in self.hardware_list:
+        if self.table_file in TableLatencyPredictor.hardware_list:
             self.hardware = self.table_file
             self.threads = 4
             self.table_file = f'{self.hardware}_threads_4_power_mode_0.pkl'
@@ -88,7 +92,7 @@
 
         assert os.path.exists(
            self.table_file
-        ), f'{self.table_file} does not exist. If you want to use our table files, please set \'table_file\' in {self.hardware_list}'
+        ), f'{self.table_file} does not exist. If you want to use our table files, please set \'table_file\' in {TableLatencyPredictor.hardware_list}'
 
         with open(self.table_file, 'rb') as f:
             self.table_dict = pickle.load(f)
@@ -123,6 +127,8 @@
         ]
         op_dir = self.table_file.split('.')[0] + '_batchsize_1'
         for op_type in op_types:
+            if data_type == 'fp32' and op_type == 'calib':
+                continue
             model = load_predictor(op_type, op_dir, data_type)
             key = op_type
             if 'conv2d' in op_type:
@@ -141,8 +147,6 @@
             model_file(str), param_file(str): The inference model(*.pdmodel, *.pdiparams).
             data_type(str): Data type, fp32, fp16 or int8.
             threads(int): Threads num.
-            sparse_ratio(float): The ratio of unstructured pruning.
-            prune_ratio(float): The ration of structured pruning.
             input_shape(list): Generally, the input shape is confirmed when saving the inference model and the parameter is only effective for input shape that has variable length.
         Returns:
             latency(float): The latency of the model.
diff --git a/tests/test_latency_predictor.py b/tests/test_latency_predictor.py
index e44fdf29a728c7930926191476acedccdc2baf8d..1767f2550fd5788ce71d14d51caf44fe9b37a408 100644
--- a/tests/test_latency_predictor.py
+++ b/tests/test_latency_predictor.py
@@ -19,7 +19,7 @@ import paddleslim
 from paddleslim.analysis import LatencyPredictor, TableLatencyPredictor
 from paddle.vision.models import mobilenet_v1, mobilenet_v2
 from paddle.nn import Conv2D, BatchNorm2D, ReLU, LayerNorm
-from paddleslim.analysis._utils import opt_model, save_cls_model, save_seg_model, save_det_model
+from paddleslim.analysis._utils import opt_model, save_cls_model, save_det_model
 
 
 def channel_shuffle(x, groups):
@@ -276,7 +276,7 @@ class TestCase5(unittest.TestCase):
         paddle.disable_static()
         model = mobilenet_v1()
         predictor = TableLatencyPredictor(table_file='SD710')
-        model_file, param_file = save_seg_model(
+        model_file, param_file = save_cls_model(
             model,
             input_shape=[1, 3, 224, 224],
             save_dir="./inference_model",
@@ -367,30 +367,6 @@ class TestCase9(unittest.TestCase):
 
 class TestCase10(unittest.TestCase):
     def test_case10(self):
-        paddle.disable_static()
-        model = ModelCase1()
-        predictor = LatencyPredictor()
-        model_file, param_file = save_seg_model(
-            model,
-            input_shape=[1, 116, 28, 28],
-            save_dir="./inference_model",
-            data_type='int8')
-        pbmodel_file = opt_model(
-            model_file=model_file,
-            param_file=param_file,
-            optimize_out_type='protobuf')
-
-        paddle.enable_static()
-        with open(pbmodel_file, "rb") as f:
-            fluid_program = paddle.fluid.framework.Program.parse_from_string(
-                f.read())
-        graph = paddleslim.core.GraphWrapper(fluid_program)
-        graph_keys = predictor._get_key_info_from_graph(graph=graph)
-        assert len(graph_keys) > 0
-
-
-class TestCase11(unittest.TestCase):
-    def test_case11(self):
         paddle.disable_static()
         model = mobilenet_v2()
         model2 = ModelCase6()