diff --git a/paddleslim/analysis/_utils.py b/paddleslim/analysis/_utils.py
index 0a2a06f5accd92303ddbb9fadbab36226dd2ad92..d0f0d95b304392f2d5e5042aea3a9edc6695759b 100644
--- a/paddleslim/analysis/_utils.py
+++ b/paddleslim/analysis/_utils.py
@@ -30,7 +30,8 @@ def opt_model(opt="paddle_lite_opt",
               param_file='',
               optimize_out_type='protobuf',
               valid_targets='arm',
-              enable_fp16=False):
+              enable_fp16=False,
+              sparse_ratio=0):
     assert os.path.exists(model_file) and os.path.exists(
         param_file), f'{model_file} or {param_file} does not exist.'
     save_dir = f'./opt_models_tmp/{os.getpid()}_{time.time()}'
@@ -42,8 +43,12 @@ def opt_model(opt="paddle_lite_opt",
         model_out = os.path.join(save_dir, 'pbmodel')
     else:
         model_out = os.path.join(save_dir, 'model')
+    enable_fp16 = str(enable_fp16).lower()
 
-    cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets} --enable_fp16={enable_fp16}'
+    sparse_model = True if sparse_ratio > 0 else False
+    sparse_threshold = max(sparse_ratio - 0.1, 0.1)
+
+    cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets} --enable_fp16={enable_fp16} --sparse_model={sparse_model} --sparse_threshold={sparse_threshold}'
     print(f'commands:{cmd}')
     m = subprocess.Popen(
         cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
diff --git a/paddleslim/analysis/latency_predictor.py b/paddleslim/analysis/latency_predictor.py
index 6760c07e5a9b6a0d06dd0d02703f305d93cda654..a67e46fd5ef32d2e2dcc774cd87f74cc652e3f82 100644
--- a/paddleslim/analysis/latency_predictor.py
+++ b/paddleslim/analysis/latency_predictor.py
@@ -72,15 +72,15 @@ class TableLatencyPredictor(LatencyPredictor):
         self.threads = None
         self.predictor_state = False
         self.predictor = {}
+        self.hardware_list = ['SD625', 'SD710']
         self._initial_table()
 
     def _initial_table(self):
-        if self.table_file in ['SD625', 'SD710', 'SD845', 'SD865']:
+        if self.table_file in self.hardware_list:
             self.hardware = self.table_file
             self.threads = 4
             self.table_file = f'{self.hardware}_threads_4_power_mode_0.pkl'
-            if self.hardware in ['SD625', 'SD710']:
-                self.predictor_state = True
+            self.predictor_state = True
         if not os.path.exists(self.table_file):
             subprocess.call(
                 f'wget https://paddlemodels.bj.bcebos.com/PaddleSlim/analysis/{self.table_file}',
@@ -88,7 +88,7 @@ class TableLatencyPredictor(LatencyPredictor):
 
         assert os.path.exists(
             self.table_file
-        ), f'{self.table_file} does not exist. If you want to use our table files, please set \'table_file\' in [SD625, SD710, SD845, SD865]'
+        ), f'{self.table_file} does not exist. If you want to use our table files, please set \'table_file\' in {self.hardware_list}'
         with open(self.table_file, 'rb') as f:
             self.table_dict = pickle.load(f)
 
@@ -140,7 +140,9 @@ class TableLatencyPredictor(LatencyPredictor):
         Args:
             model_file(str), param_file(str): The inference model(*.pdmodel, *.pdiparams).
             data_type(str): Data type, fp32, fp16 or int8.
-            threads(int): threads num
+            threads(int): Threads num.
+            sparse_ratio(float): The ratio of unstructured pruning.
+            prune_ratio(float): The ratio of structured pruning.
             input_shape(list): Generally, the input shape is confirmed when saving the inference model and the parameter is only effective for input shape that has variable length.
         Returns:
             latency(float): The latency of the model.
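Usage note (not part of the patch): a minimal sketch of how the new sparse-latency path might be driven, assuming the constructor takes the hardware name via the `table_file` argument and that `predict` accepts the keyword names listed in the docstring above; the model paths are placeholders.

    from paddleslim.analysis import TableLatencyPredictor

    # 'SD710' is one of the entries in hardware_list; the table file
    # SD710_threads_4_power_mode_0.pkl is downloaded on first use.
    predictor = TableLatencyPredictor(table_file='SD710')
    latency = predictor.predict(
        model_file='./infer_model/inference.pdmodel',    # placeholder path
        param_file='./infer_model/inference.pdiparams',  # placeholder path
        data_type='fp32',
        threads=4,
        sparse_ratio=0.75)  # predict latency at 75% unstructured sparsity
    print(latency)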
diff --git a/paddleslim/analysis/parse_ops.py b/paddleslim/analysis/parse_ops.py
index ecb8d5297b995ead752b97163b364d77510fb2d3..2e53342863e37e4bf46edd65ed64c846e6810364 100644
--- a/paddleslim/analysis/parse_ops.py
+++ b/paddleslim/analysis/parse_ops.py
@@ -21,22 +21,38 @@ def get_key_from_op(op):
     param_key = ''
     op_type = op.type()
 
-    if 'conv2d' in op_type:
+    if op_type == 'sparse_conv2d':
+        out_shape = op.all_outputs()[0].shape()
+        in_shape = op.inputs('Input')[0].shape()
+        weight_shape = (out_shape[1], in_shape[1], 1, 1)
+        NonZeroWeights = op.inputs('NonZeroWeights')[0].shape()[0]
+
+        stride = op.attr('strides')[1]
+        padding = op.attr('paddings')[1]
+        groups = op.attr('groups')
+        dilation = op.attr('dilations')[1]
+        quant = op.attr('enable_int8')
+        bit_length = op.attr('bit_length')
+
+        param_key = f'{op_type} in={in_shape} weight={weight_shape} out={out_shape} pad={padding} stride={stride} group={groups} dilation={dilation} quant={quant} bit_length={bit_length} NonZeroWeights={NonZeroWeights}'
+
+    elif 'conv2d' in op_type:
         out_shape = op.all_outputs()[0].shape()
         in_shape = op.all_inputs()[-1].shape()
         in_name = op.all_inputs()[1].name()
         weight_shape = op.all_inputs()[-2].shape()
-        weight_shape = (out_shape[1], weight_shape[1], weight_shape[2], weight_shape[3])
-
+        weight_shape = (out_shape[1], weight_shape[1], weight_shape[2],
+                        weight_shape[3])
+
         stride = op.attr('strides')[1]
         padding = op.attr('paddings')[1]
         groups = op.attr('groups')
         dilation = op.attr('dilations')[1]
         quant = op.attr('enable_int8')
         bit_length = op.attr('bit_length')
-        if op.attr(in_name+'_fp16') == 'fp16':
+        if op.attr(in_name + '_fp16') == 'fp16':
             quant = True
-            bit_length = 16
+            bit_length = 16
 
         param_key = f'{op_type} in={in_shape} weight={weight_shape} out={out_shape} pad={padding} stride={stride} group={groups} dilation={dilation} quant={quant} bit_length={bit_length}'
diff --git a/paddleslim/auto_compression/utils/prune_model.py b/paddleslim/auto_compression/utils/prune_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..67b775ef64060ea56c4394d7f4063cdde0634c59
--- /dev/null
+++ b/paddleslim/auto_compression/utils/prune_model.py
@@ -0,0 +1,169 @@
+import os
+import time
+import paddle
+import paddle.fluid as fluid
+import paddle.static as static
+from paddleslim.prune import Pruner
+from paddleslim.core import GraphWrapper
+import numpy as np
+__all__ = ["get_sparse_model", "get_prune_model"]
+
+
+def get_sparse_model(model_file, param_file, ratio, save_path):
+    """
+    Compress the network with the unstructured sparsity algorithm.
+    This interface is only used to evaluate the latency of the compressed network; it does not consider the loss of accuracy.
+    Args:
+        model_file(str), param_file(str): The inference model to be pruned.
+        ratio(float): The ratio to prune the model.
+        save_path(str): The save path of the pruned model.
+    """
+    assert os.path.exists(model_file), f'{model_file} does not exist.'
+    assert param_file is None or os.path.exists(
+        param_file), f'{param_file} does not exist.'
+    paddle.enable_static()
+
+    SKIP = ['image', 'feed', 'pool2d_0.tmp_0']
+
+    folder = os.path.dirname(model_file)
+    model_name = model_file.split('/')[-1]
+    if param_file is None:
+        param_name = None
+    else:
+        param_name = param_file.split('/')[-1]
+
+    main_prog = static.Program()
+    startup_prog = static.Program()
+    exe = paddle.static.Executor(paddle.CPUPlace())
+    exe.run(startup_prog)
+
+    [inference_program, feed_target_names, fetch_targets] = (
+        fluid.io.load_inference_model(
+            folder, exe, model_filename=model_name, params_filename=param_name))
+    thresholds = {}
+
+    graph = GraphWrapper(inference_program)
+    for op in graph.ops():
+        for inp in op.all_inputs():
+            name = inp.name()
+            if inp.name() in SKIP: continue
+            if 'tmp' in inp.name(): continue
+            # 1x1_conv
+            cond_conv = len(inp._var.shape) == 4 and inp._var.shape[
+                2] == 1 and inp._var.shape[3] == 1
+            cond_fc = False
+
+            if cond_fc or cond_conv:
+                array = np.array(paddle.static.global_scope().find_var(name)
+                                 .get_tensor())
+                flatten = np.abs(array.flatten())
+                index = min(len(flatten) - 1, int(ratio * len(flatten)))
+                ind = np.unravel_index(
+                    np.argsort(
+                        flatten, axis=None), flatten.shape)
+                thresholds[name] = ind[0][:index]
+
+    for op in graph.ops():
+        for inp in op.all_inputs():
+            name = inp.name()
+            if name in SKIP: continue
+            if 'tmp' in inp.name(): continue
+
+            cond_conv = (len(inp._var.shape) == 4 and inp._var.shape[2] == 1 and
+                         inp._var.shape[3] == 1)
+            cond_fc = False
+
+            # only support 1x1_conv now
+            if not (cond_conv or cond_fc): continue
+            array = np.array(paddle.static.global_scope().find_var(name)
+                             .get_tensor())
+            if thresholds.get(name) is not None:
+                np.put(array, thresholds.get(name), 0)
+            assert (abs(1 - np.count_nonzero(array) / array.size - ratio) < 1e-2
+                    ), 'The model sparsity is abnormal.'
+            paddle.static.global_scope().find_var(name).get_tensor().set(
+                array, paddle.CPUPlace())
+
+    fluid.io.save_inference_model(
+        save_path,
+        feeded_var_names=feed_target_names,
+        target_vars=fetch_targets,
+        executor=exe,
+        main_program=inference_program,
+        model_filename=model_name,
+        params_filename=param_name)
+    print("The pruned model is saved in: ", save_path)
+
+
+def get_prune_model(model_file, param_file, ratio, save_path):
+    """
+    Compress the network with the structured pruning algorithm.
+    This interface is only used to evaluate the latency of the compressed network; it does not consider the loss of accuracy.
+    Args:
+        model_file(str), param_file(str): The inference model to be pruned.
+        ratio(float): The ratio to prune the model.
+        save_path(str): The save path of the pruned model.
+    """
+
+    assert os.path.exists(model_file), f'{model_file} does not exist.'
+    assert param_file is None or os.path.exists(
+        param_file), f'{param_file} does not exist.'
+    paddle.enable_static()
+
+    SKIP = ['image', 'feed', 'pool2d_0.tmp_0']
+
+    folder = os.path.dirname(model_file)
+    model_name = model_file.split('/')[-1]
+    if param_file is None:
+        param_name = None
+    else:
+        param_name = param_file.split('/')[-1]
+
+    main_prog = static.Program()
+    startup_prog = static.Program()
+    place = paddle.CPUPlace()
+    exe = paddle.static.Executor()
+    scope = static.global_scope()
+    exe.run(startup_prog)
+
+    [inference_program, feed_target_names, fetch_targets] = (
+        fluid.io.load_inference_model(
+            folder, exe, model_filename=model_name, params_filename=param_name))
+
+    prune_params = []
+    graph = GraphWrapper(inference_program)
+    for op in graph.ops():
+        for inp in op.all_inputs():
+            name = inp.name()
+            if inp.name() in SKIP: continue
+            if 'tmp' in inp.name(): continue
+            cond_conv = len(inp._var.shape) == 4 and 'conv' in name
+            # only prune conv
+            if cond_conv:
+                prune_params.append(name)
+
+    # drop last conv
+    prune_params.pop()
+    ratios = [ratio] * len(prune_params)
+
+    pruner = Pruner()
+    main_program, _, _ = pruner.prune(
+        inference_program,
+        scope,
+        params=prune_params,
+        ratios=ratios,
+        place=place,
+        lazy=False,
+        only_graph=False,
+        param_backup=None,
+        param_shape_backup=None)
+
+    fluid.io.save_inference_model(
+        save_path,
+        feeded_var_names=feed_target_names,
+        target_vars=fetch_targets,
+        executor=exe,
+        main_program=main_program,
+        model_filename=model_name,
+        params_filename=param_name)
diff --git a/paddleslim/core/graph_wrapper.py b/paddleslim/core/graph_wrapper.py
index f2dd05272c6de67bb6963efb1679985a84089b9b..b0a9f14e5ad022cdb26f1986e1a7796b723731d0 100644
--- a/paddleslim/core/graph_wrapper.py
+++ b/paddleslim/core/graph_wrapper.py
@@ -387,5 +387,7 @@ class GraphWrapper(object):
         It is used after loading pruned parameters from file.
         """
         for op in self.ops():
-            if op.type() != 'conditional_block' and op.type() != 'feed':
+            if op.type() in ['feed', 'fetch']:
+                continue
+            if op.type() != 'conditional_block':
                 op._op.desc.infer_shape(op._op.block.desc)
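Usage note (also not part of the patch): the two helpers added in prune_model.py can be exercised on their own to produce latency-evaluation copies of a model. The import path below is inferred from the new file's location, and the model paths are placeholders.

    from paddleslim.auto_compression.utils.prune_model import get_sparse_model, get_prune_model

    # Zero out roughly 75% of the 1x1-conv weights (unstructured sparsity).
    get_sparse_model('./infer_model/inference.pdmodel',
                     './infer_model/inference.pdiparams',
                     ratio=0.75, save_path='./sparse_model')

    # Channel-prune the conv weights by 25% (structured pruning).
    get_prune_model('./infer_model/inference.pdmodel',
                    './infer_model/inference.pdiparams',
                    ratio=0.25, save_path='./prune_model')

Both helpers rewrite weights purely so that latency can be measured; the accuracy of the saved models is not meaningful.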