diff --git a/docs/zh_cn/tutorials/analysis/dygraph/latency_predictor.md b/docs/zh_cn/tutorials/analysis/dygraph/latency_predictor.md index 55d9951a2421a2402c02eccb1c0ad04968c46b13..860b96a37213dfaaae5c4447805fba07f5fa1b2b 100644 --- a/docs/zh_cn/tutorials/analysis/dygraph/latency_predictor.md +++ b/docs/zh_cn/tutorials/analysis/dygraph/latency_predictor.md @@ -37,9 +37,9 @@ tar -xf mobilenetv1.tar ### 2.2 预估推理延时 构造 TableLatencyPredictor 类实例,并调用 predict 函数预估推理模型的延时。 ``` -import paddleslim +from paddleslim.analysis import TableLatencyPredictor -predictor = paddleslim.TableLatencyPredictor(table_file='SD710') +predictor = TableLatencyPredictor(table_file='SD710') latency = predictor.predict(model_file='mobilenetv1_fp32.pdmodel', param_file='mobilenetv1_fp32.pdiparams, data_type='fp32') print('predicted latency = {}ms'.format(latency)) ``` @@ -58,9 +58,9 @@ print('predicted latency = {}ms'.format(latency)) ### 3.2 支持预估 INT8 模型 延时预估器支持对 INT8 量化模型进行延时预估,仅需提供 INT8 量化保存的推理模型文件,并将在调用 predict 函数时,设置 data_type='int8',如下所示: ``` -import paddleslim +from paddleslim.analysis import TableLatencyPredictor -predictor = paddleslim.TableLatencyPredictor(table_file='SD710') +predictor = TableLatencyPredictor(table_file='SD710') predictor.predict(model_file='mobilenetv1_int8.pdmodel', param_file='mobilenetv1_int8.pdiparams, data_type='int8') ``` diff --git a/paddleslim/analysis/_utils.py b/paddleslim/analysis/_utils.py index bb9bd5694fd62852136beb11ec72303359416cd0..bd674b5dc876a5b8b9cd7d33094f6cc2adbf20b5 100644 --- a/paddleslim/analysis/_utils.py +++ b/paddleslim/analysis/_utils.py @@ -31,7 +31,7 @@ def opt_model(opt="paddle_lite_opt", optimize_out_type='protobuf', valid_targets='arm'): assert os.path.exists(model_file) and os.path.exists( - param_file), f'{model_file} or {param_file} is not existed.' + param_file), f'{model_file} or {param_file} does not exist.' save_dir = f'./opt_models_tmp/{os.getpid()}' if not os.path.exists(save_dir): os.makedirs(save_dir) @@ -225,11 +225,11 @@ def nearest_interpolate(features, data): return latency[idx] -def dowload_predictor(op_dir, op): - """Dowload op predictors' model file +def download_predictor(op_dir, op): + """Download op predictors' model file Args: - op_dir(str): the dowload path of op predictor. Actually, it's the hardware information. + op_dir(str): the path to op predictor. Actually, it's the hardware information. op(str): the op type. Returns: op_path: The path of the file. @@ -252,7 +252,7 @@ def load_predictor(op_type, op_dir, data_type='fp32'): elif 'matmul' in op_type: op = 'matmul' - op_path = dowload_predictor(op_dir, op) + op_path = download_predictor(op_dir, op) with open(op_path, 'rb') as f: model = pickle.load(f) diff --git a/paddleslim/analysis/extract_features.py b/paddleslim/analysis/extract_features.py index 8410cb263c8d70a57247c60620b7e0e7cac31ce3..38f67bce924b00055854ac9cca67305411d1c073 100644 --- a/paddleslim/analysis/extract_features.py +++ b/paddleslim/analysis/extract_features.py @@ -38,7 +38,9 @@ def get_data_from_tables(table_dict, op_type, data_type='fp32'): features = get_features_from_paramkey(param_key, op_type, data_type) if features == None: continue - + # only support bs=1 now + if features[0] != 1: + continue features.append(table_dict[param_key]) data.append(features) return np.array(data) @@ -47,7 +49,7 @@ def get_data_from_tables(table_dict, op_type, data_type='fp32'): def get_features_from_paramkey(param_key, op_type, data_type): """Get op's parameters according to the key of latency table """ - features = [] + features = None if 'conv2d' in op_type: flag_quant = 'quant=None' if data_type == 'fp32' else 'quant=True' @@ -58,12 +60,12 @@ def get_features_from_paramkey(param_key, op_type, data_type): param_key).group().split('=')[-1].strip( '(' ')').split(', ') - outputs = re.search(r'out=(\(-*\d*, \d*, \d*, \d*\))', + outputs = re.search(r'out=(\(-*\d*, \d*, -?\d*, -?\d*\))', param_key).group().split('=')[-1].strip( '(' ')').split(', ') - - cout = int(weight[0]) + batchsize = int(outputs[0]) + cout = int(outputs[1]) cin = int(weight[1]) kernel = int(weight[2]) out_h = int(outputs[2]) @@ -75,7 +77,7 @@ def get_features_from_paramkey(param_key, op_type, data_type): out_w) if data_type == 'fp32': - inputs = re.search(r'in=(\(-*\d*, \d*, \d*, \d*\))', + inputs = re.search(r'in=(\(-*\d*, \d*, -?\d*, -?\d*\))', param_key).group().split('=')[-1].strip( '(' ')').split(', ') @@ -84,20 +86,20 @@ def get_features_from_paramkey(param_key, op_type, data_type): in_w = int(inputs[3]) features = [ - in_c, cout, kernel, group, stride, pad, in_h * in_w, + batchsize, in_c, cout, kernel, group, stride, pad, in_h * in_w, out_h * out_w ] else: features = [ - cin, cout, kernel, group, stride, pad, out_h * out_w, flops, - params + batchsize, cin, cout, kernel, group, stride, pad, out_h * out_w, + flops, params ] elif 'matmul' in op_type: - X = re.search(r'X=(\(-*\d*, \d*\))', + X = re.search(r'X=(\((-?\d+,* *)+\))', param_key).group().split('=')[-1].strip('(' ')').split(', ') - Y = re.search(r'Y=(\(\d*, \d*\))', + Y = re.search(r'Y=(\((-?\d+,* *)+\))', param_key).group().split('=')[-1].strip('(' ')').split(', ') @@ -106,7 +108,7 @@ def get_features_from_paramkey(param_key, op_type, data_type): c = int(Y[1]) flops, params = cal_flops_params('fc', b, c) - features = [b, c, flops, params] + features = [a, b, c, flops, params] elif ('batch_norm' in op_type or 'layer_norm' in op_type): inputs = re.search(r'in=(\((-?\d+,* *)+\))', @@ -114,11 +116,11 @@ def get_features_from_paramkey(param_key, op_type, data_type): '(' ')').split(', ') - features = [0, 0, 0] - for i in range(1, len(inputs)): + features = [0, 0, 0, 0] + for i in range(len(inputs)): if inputs[i] == '': continue - features[i - 1] = int(inputs[i]) + features[i] = int(inputs[i]) elif 'pool2d' in op_type: @@ -130,7 +132,7 @@ def get_features_from_paramkey(param_key, op_type, data_type): param_key).group().split('=')[-1].strip( '(' ')').split(', ') - + batchsize = int(inputs[0]) cin = int(inputs[1]) in_h = int(inputs[2]) in_w = int(inputs[3]) @@ -147,7 +149,8 @@ def get_features_from_paramkey(param_key, op_type, data_type): flag_type = 1 if 'type=avg' in param_key else 0 features = [ - cin, kernel, stride, pad, in_h * in_w, out_h * out_w, flag_type + batchsize, cin, kernel, stride, pad, in_h * in_w, out_h * out_w, + flag_type ] elif ('reshape' in op_type or 'scale' in op_type): @@ -160,8 +163,8 @@ def get_features_from_paramkey(param_key, op_type, data_type): '(' ')').split(',') - # inputs[4], ouputs[4] - features = [0, 0, 0, 0, 0, 0, 0, 0] + # inputs[4], ouputs[4]/[5] + features = [0, 0, 0, 0, 0, 0, 0, 0, 0] for i in range(len(inputs)): if inputs[i] == '': continue @@ -175,22 +178,28 @@ def get_features_from_paramkey(param_key, op_type, data_type): 'leaky_relu' in op_type or 'tanh' in op_type or 'swish' in op_type or 'softmax' in op_type or 'hard_sigmoid' in op_type or 'sigmoid' in op_type or 'gelu' in op_type or 'clip' in op_type or - 'shape' in op_type or 'transpose' in op_type or - 'interp_v2' in op_type): + 'shape' in op_type or 'interp_v2' in op_type): inputs = re.search(r'in=(\((-?\d+,* *)+\))', param_key).group().split('=')[-1].strip( '(' ')').split(', ') - #cin, h, w - cin = int(inputs[1]) - in_h = 0 - in_w = 0 - if len(inputs) == 4: - in_h = int(inputs[2]) - in_w = int(inputs[3]) - features = [cin, in_h, in_w] + # N, C, H, W + features = [0, 0, 0, 0] + for i in range(len(inputs)): + features[i] = int(inputs[i]) + + elif 'transpose' in op_type: + inputs = re.search(r'in=(\((-?\d+,* *)+\))', + param_key).group().split('=')[-1].strip( + '(' + ')').split(', ') + + # inputs[4]/[5] + features = [0, 0, 0, 0, 0] + for i in range(len(inputs)): + features[i] = int(inputs[i]) elif 'elementwise' in op_type: X = re.search(r'X=\((-?\d+,* *)+\)', @@ -200,9 +209,9 @@ def get_features_from_paramkey(param_key, op_type, data_type): Y = re.search(r'Y=\((-?\d+,* *)+\)', param_key).group().split('=')[-1].strip('(' ')').split(',') - # X[1] X[2] X[3] Y[1] Y[2] Y[3] - features = [0, 0, 0, 0, 0, 0] - for i in range(1, len(X)): + # X[0] X[1] X[2] X[3] Y[1] Y[2] Y[3] + features = [0, 0, 0, 0, 0, 0, 0] + for i in range(len(X)): if X[i] == '': continue features[i - 1] = int(X[i]) @@ -211,7 +220,7 @@ def get_features_from_paramkey(param_key, op_type, data_type): continue if Y[i] == '': continue - features[i + 2] = int(Y[i]) + features[i + 3] = int(Y[i]) elif 'concat' in op_type: inputs = re.search(r'in=(\((-?\d+,* *)+\))+', @@ -222,16 +231,17 @@ def get_features_from_paramkey(param_key, op_type, data_type): channels = [] for ins in inputs: channels.append(int(ins.split(', ')[1])) - #hw, c1,c2,c3,c4,c5,c6,c7,c8,c9 - features = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + # bs, hw, c1,c2,c3,c4,c5,c6,c7,c8,c9 + features = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] input1 = inputs[0].split(', ') + features[0] = int(input1[0]) if len(input1) == 3: - features[0] = int(input1[2]) + features[1] = int(input1[2]) else: - features[0] = int(input1[2]) * int(input1[3]) + features[1] = int(input1[2]) * int(input1[3]) for i in range(len(channels)): - features[i + 1] = channels[i] + features[i + 2] = channels[i] elif 'yolo_box' in op_type: outputs = re.search(r'out=(\(-?\d*, \d*, \d*\))', @@ -242,7 +252,7 @@ def get_features_from_paramkey(param_key, op_type, data_type): param_key).group().split('=')[-1].strip( '(' ')').split(', ') - + batchsize = int(inputs[0]) cin = int(inputs[1]) h = int(inputs[2]) w = int(inputs[3]) @@ -250,19 +260,19 @@ def get_features_from_paramkey(param_key, op_type, data_type): class_num = int( re.search(r'class_num=\d*', param_key).group().split('=')[-1]) - features = [cin, h * w, cout, class_num] + features = [batchsize, cin, h * w, cout, class_num] elif 'prior_box' in op_type: inputs = re.search(r'in=\((-?\d+,* *)+\)', param_key).group().split('=')[-1].strip( '(' ')').split(',') - + batchsize = int(inputs[0]) cin = int(inputs[1]) h = int(inputs[2]) w = int(inputs[3]) - features = [cin, h, w] + features = [batchsize, cin, h, w] elif 'slice' in op_type: inputs = re.search(r'in=\((-?\d+,* *)+\)', @@ -298,40 +308,30 @@ def get_features_from_paramkey(param_key, op_type, data_type): continue features[i] = int(inputs[i]) - elif 'fc' in op_type: - weight = re.search(r'weight=(\(\d*, \d*, \d*, \d*\))', - param_key).group().split('=')[-1].strip( - '(' - ')').split(', ') - - cin = int(weight[0]) - cout = int(weight[1]) - flops, params = cal_flops_params('fc', cin, cout) - - features = [cin, cout, flops, params] - elif 'shuffle_channel' in op_type: inputs = re.search(r'in=(\(-*\d*, \d*, \d*, \d*\))', param_key).group().split('=')[-1].strip( '(' ')').split(', ') + batchsize = int(inputs[0]) cin = int(inputs[1]) in_h = int(inputs[2]) in_w = int(inputs[3]) group = int(re.search(r'group=\d*', param_key).group().split('=')[1]) - features = [cin, in_h, in_w, group] + features = [batchsize, cin, in_h, in_w, group] elif 'split' in op_type: - inputs = re.search(r'in=(\(-*\d*, \d*, \d*, \d*\))', + inputs = re.search(r'in=(\(-*\d*, \d*, -?\d*, -?\d*\))', param_key).group().split('=')[-1].strip( '(' ')').split(', ') + batchsize = int(inputs[0]) cin = int(inputs[1]) in_h = int(inputs[2]) in_w = int(inputs[2]) - features = [cin, in_h, in_w] + features = [batchsize, cin, in_h, in_w] elif 'squeeze' in op_type: inputs = re.search(r'in=\((-?\d+,* *)+\)', @@ -350,7 +350,9 @@ def get_features_from_paramkey(param_key, op_type, data_type): '(' ')').split(', ') - features = [int(inputs[1]), int(inputs[2]), int(inputs[3])] + features = [ + int(inputs[0]), int(inputs[1]), int(inputs[2]), int(inputs[3]) + ] elif ('calib' in op_type or 'floor' in op_type): inputs = re.search(r'in=\((-?\d+,* *)+\)', @@ -361,12 +363,12 @@ def get_features_from_paramkey(param_key, op_type, data_type): param_key).group().split('=')[-1].strip( '(' ')').split(',') - - features = [0, 0, 0, 0, 0, 0] - for i in range(1, len(inputs)): - features[i - 1] = int(inputs[i]) - for i in range(1, len(outputs)): - features[i + 2] = int(outputs[i]) + # inputs[4] outputs[4] + features = [0, 0, 0, 0, 0, 0, 0, 0] + for i in range(len(inputs)): + features[i] = int(inputs[i]) + for i in range(len(outputs)): + features[i + 4] = int(outputs[i]) elif 'uniform_random' in op_type: shape = re.search(r'shape=\[(-?\d+,* *)+\]', @@ -379,4 +381,69 @@ def get_features_from_paramkey(param_key, op_type, data_type): continue features[i] = int(shape[i]) + elif 'arg_max' in op_type: + inputs = re.search(r'in=\((-?\d+,* *)+\)', + param_key).group().split('=')[-1].strip( + '(' + ')').split(',') + outputs = re.search(r'out=\((-?\d+,* *)+\)', + param_key).group().split('=')[-1].strip( + '(' + ')').split(',') + + # inputs[4], outputs[4] + features = [0, 0, 0, 0, 0, 0, 0, 0] + for i in range(len(inputs)): + if inputs[i] == '': + continue + features[i] = int(inputs[i]) + for i in range(len(outputs)): + if outputs[i] == '': + continue + features[i + 4] = int(outputs[i]) + + elif 'fill_constant_batch_size_like' in op_type: + inputs = re.search(r'in=\((-?\d+,* *)+\)', + param_key).group().split('=')[-1].strip( + '(' + ')').split(',') + outputs = re.search(r'out=\((-?\d+,* *)+\)', + param_key).group().split('=')[-1].strip( + '(' + ')').split(',') + + # inputs[4], outputs[4] + features = [0, 0, 0, 0, 0, 0, 0, 0] + for i in range(len(inputs)): + if inputs[i] == '': + continue + features[i] = int(inputs[i]) + for i in range(len(outputs)): + if outputs[i] == '': + continue + features[i + 4] = int(outputs[i]) + + elif op_type == 'rnn': + inputs = re.search(r'in=\((-?\d+,* *)+\)', + param_key).group().split('=')[-1].strip( + '(' + ')').split(',') + inputs[0], inputs[1] = inputs[1], inputs[0] + outputs = re.search(r'out=\((-?\d+,* *)+\)', + param_key).group().split('=')[-1].strip( + '(' + ')').split(',') + outputs[0], outputs[1] = outputs[1], outputs[0] + + # inputs[3], outputs[3] + features = [0, 0, 0, 0, 0, 0] + for i in range(len(inputs)): + if inputs[i] == '': + continue + features[i] = int(inputs[i]) + for i in range(len(outputs)): + if outputs[i] == '': + continue + features[i + 3] = int(outputs[i]) + return features diff --git a/paddleslim/analysis/latency_predictor.py b/paddleslim/analysis/latency_predictor.py index a2cf5e7bcdf71a1bf2981d3fd3bf7c43a854f656..579c29eb05c584bf3a5e932829b586b549fb6943 100644 --- a/paddleslim/analysis/latency_predictor.py +++ b/paddleslim/analysis/latency_predictor.py @@ -23,9 +23,18 @@ from .extract_features import get_data_from_tables, get_features_from_paramkey from ._utils import opt_model, load_predictor, nearest_interpolate import paddle import paddleslim +import warnings __all__ = ["LatencyPredictor", "TableLatencyPredictor"] +def format_Warning(message, category, filename, lineno, line=''): + return str(filename) + ':' + str( + lineno) + ': ' + category.__name__ + ': ' + str(message) + '\n' + + +warnings.formatwarning = format_Warning + + class LatencyPredictor(object): """Base class of latency predictor. """ @@ -53,7 +62,7 @@ class TableLatencyPredictor(LatencyPredictor): """The preditor used to get pbmodel's latency on some devices and infer engines. Args: - table_file(str): The path of file that records the devices latency of operators. + table_file(str): The path of file that records the device latency of operators. """ def __init__(self, table_file='SD710'): @@ -78,7 +87,7 @@ class TableLatencyPredictor(LatencyPredictor): assert os.path.exists( self.table_file - ), f'{self.table_file} is not existed. If you want to use our table files, please set \'table_file\' in [SD625, SD710, SD845, SD865]' + ), f'{self.table_file} does not exist. If you want to use our table files, please set \'table_file\' in [SD625, SD710, SD845, SD865]' with open(self.table_file, 'rb') as f: self.table_dict = pickle.load(f) @@ -95,7 +104,7 @@ class TableLatencyPredictor(LatencyPredictor): with open(self.table_file, 'rb') as f: self.table_dict = pickle.load(f) - print('Successfully load {}'.format(self.table_file)) + print('Successfully loaded {}'.format(self.table_file)) def _get_input_shape(self, graph): in_shape = [] @@ -118,7 +127,7 @@ class TableLatencyPredictor(LatencyPredictor): model_file(str), param_file(str): The inference model(*.pdmodel, *.pdiparams). data_type(str): Data type, fp32 or int8. Default : fp32 threads(int): threads num - input_shape(list): Generally, the input shape is confirmed when saving the inference model and the parameter is only effective for variable length input shape. + input_shape(list): Generally, the input shape is confirmed when saving the inference model and the parameter is only effective for input shape that has variable length. Returns: latency(float): The latency of the model. """ @@ -142,19 +151,31 @@ class TableLatencyPredictor(LatencyPredictor): if input_shape != None: ori_shape = self._get_input_shape(graph) - assert ori_shape == input_shape, "The parameter \'input_shape\' dosn't work now. The input shape is confirmed when saving the inference model" + assert ori_shape == input_shape, "The parameter \'input_shape\' dosn't work for now. The input shape is fixed when saving the inference model" latency = 0.0 + new_op = {} for op in graph.ops(): param_key = get_key_from_op(op) if param_key == '': continue + if param_key == None: + if op.type() in new_op: + new_op[op.type()] += 1 + else: + new_op.update({op.type(): 1}) + continue if param_key in self.table_dict: latency += self.table_dict[param_key] elif self.predictor_state: latency += self.op_predictor(op.type(), param_key, data_type) - else: - raise AssertionError(f'{param_key} is not in the table.') + if len(new_op) != 0: + warnings.warn( + "These ops are not currently supported. Please raise an issue in PaddleSlim if you find the CalledTimes is large enough to affect the accuracy." + ) + warnings.warn("OperatorType\tCalledTimes") + for key in new_op: + warnings.warn(f"{key.ljust(15)}\t{new_op[key]}") return latency diff --git a/paddleslim/analysis/parse_ops.py b/paddleslim/analysis/parse_ops.py index c8c98f829bb00a6c111d717ab74fc80ee03a26fa..5490428135d325fb4ed11eeac5f9a54d6a915ecc 100644 --- a/paddleslim/analysis/parse_ops.py +++ b/paddleslim/analysis/parse_ops.py @@ -234,12 +234,28 @@ def get_key_from_op(op): param_key = f'{op_type} in={in_shape} out={out_shape} paddings={paddings}' + elif op_type == 'arg_max': + in_shape = op.all_inputs()[-1].shape() + out_shape = op.all_outputs()[0].shape() + axis = op.attr('axis') + + param_key = f'{op_type} in={in_shape} axis={axis} out={out_shape}' + + elif op_type == 'fill_constant_batch_size_like': + in_shape = op.all_inputs()[-1].shape() + out_shape = op.all_outputs()[0].shape() + shape = op.attr('shape') + param_key = f'{op_type} in={in_shape} shape={shape} out={out_shape}' + + elif op_type == 'rnn': + out_shape = op.all_outputs()[1].shape() + in_shape = op.all_inputs()[0].shape() + param_key = f'{op_type} in={in_shape} out={out_shape}' + elif op_type in ['feed', 'fetch']: pass else: - print(op) - print(op._op) - raise KeyError(f'The "{op_type}" has never seen.') + param_key = None return param_key diff --git a/tests/test_latency_predictor.py b/tests/test_latency_predictor.py index 5e82b8a79f43344550ec3b64f733e61384429988..e44fdf29a728c7930926191476acedccdc2baf8d 100644 --- a/tests/test_latency_predictor.py +++ b/tests/test_latency_predictor.py @@ -140,6 +140,8 @@ class ModelCase6(paddle.nn.Layer): self.relu1 = ReLU() self.fc1 = paddle.nn.Linear(3 * 16 * 16, 3 * 16 * 16) self.dp = paddle.nn.Dropout(p=0.5) + self.lstm = paddle.nn.LSTM( + 1536, 10, direction='bidirectional', num_layers=2) def forward(self, inputs): x = self.bn1(inputs) @@ -149,17 +151,24 @@ class ModelCase6(paddle.nn.Layer): x = self.relu1(x) y = paddle.fluid.layers.fill_constant( x.shape, dtype=paddle.float32, value=1) - x = paddle.stack([x, y], axis=3) + # x = paddle.stack([x, y], axis=3) x = paddle.slice(x, axes=[0], starts=[0], ends=[1]) x = paddle.exp(x) - y += paddle.fluid.layers.uniform_random(y.shape) + # y += paddle.fluid.layers.uniform_random(y.shape) y = paddle.expand(y, shape=[1, 768, 768, 2]) x = paddle.expand(x, shape=[1, 768, 768, 2]) out = paddle.concat([x, y]) out = self.dp(out) out = channel_shuffle(out, 2) out1, out2 = paddle.split(out, num_or_sections=2, axis=1) - return out1, out2 + outshape = out1.shape + max_idx = paddle.argmax( + out1.reshape((outshape[0], outshape[1], outshape[2] * outshape[3])), + axis=-1) + out2 = out2.reshape( + (outshape[0], outshape[1], outshape[2] * outshape[3])) + res, _ = self.lstm(out2) + return res, max_idx class ModelCase7(paddle.nn.Layer):