未验证 提交 0706349b 编写于 作者: Z zhouzj 提交者: GitHub

[LatencyPredictor] Add some ops for ocr and tinypose. (#991)

上级 f89fc93c
...@@ -37,9 +37,9 @@ tar -xf mobilenetv1.tar ...@@ -37,9 +37,9 @@ tar -xf mobilenetv1.tar
### 2.2 预估推理延时 ### 2.2 预估推理延时
构造 TableLatencyPredictor 类实例,并调用 predict 函数预估推理模型的延时。 构造 TableLatencyPredictor 类实例,并调用 predict 函数预估推理模型的延时。
``` ```
import paddleslim from paddleslim.analysis import TableLatencyPredictor
predictor = paddleslim.TableLatencyPredictor(table_file='SD710') predictor = TableLatencyPredictor(table_file='SD710')
latency = predictor.predict(model_file='mobilenetv1_fp32.pdmodel', param_file='mobilenetv1_fp32.pdiparams', data_type='fp32') latency = predictor.predict(model_file='mobilenetv1_fp32.pdmodel', param_file='mobilenetv1_fp32.pdiparams', data_type='fp32')
print('predicted latency = {}ms'.format(latency)) print('predicted latency = {}ms'.format(latency))
``` ```
...@@ -58,9 +58,9 @@ print('predicted latency = {}ms'.format(latency)) ...@@ -58,9 +58,9 @@ print('predicted latency = {}ms'.format(latency))
### 3.2 支持预估 INT8 模型 ### 3.2 支持预估 INT8 模型
延时预估器支持对 INT8 量化模型进行延时预估,仅需提供 INT8 量化保存的推理模型文件,并在调用 predict 函数时设置 data_type='int8',如下所示: 延时预估器支持对 INT8 量化模型进行延时预估,仅需提供 INT8 量化保存的推理模型文件,并在调用 predict 函数时设置 data_type='int8',如下所示:
``` ```
import paddleslim from paddleslim.analysis import TableLatencyPredictor
predictor = paddleslim.TableLatencyPredictor(table_file='SD710') predictor = TableLatencyPredictor(table_file='SD710')
predictor.predict(model_file='mobilenetv1_int8.pdmodel', param_file='mobilenetv1_int8.pdiparams', data_type='int8') predictor.predict(model_file='mobilenetv1_int8.pdmodel', param_file='mobilenetv1_int8.pdiparams', data_type='int8')
``` ```
......
...@@ -31,7 +31,7 @@ def opt_model(opt="paddle_lite_opt", ...@@ -31,7 +31,7 @@ def opt_model(opt="paddle_lite_opt",
optimize_out_type='protobuf', optimize_out_type='protobuf',
valid_targets='arm'): valid_targets='arm'):
assert os.path.exists(model_file) and os.path.exists( assert os.path.exists(model_file) and os.path.exists(
param_file), f'{model_file} or {param_file} is not existed.' param_file), f'{model_file} or {param_file} does not exist.'
save_dir = f'./opt_models_tmp/{os.getpid()}' save_dir = f'./opt_models_tmp/{os.getpid()}'
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
...@@ -225,11 +225,11 @@ def nearest_interpolate(features, data): ...@@ -225,11 +225,11 @@ def nearest_interpolate(features, data):
return latency[idx] return latency[idx]
def dowload_predictor(op_dir, op): def download_predictor(op_dir, op):
"""Dowload op predictors' model file """Download op predictors' model file
Args: Args:
op_dir(str): the dowload path of op predictor. Actually, it's the hardware information. op_dir(str): the path to op predictor. Actually, it's the hardware information.
op(str): the op type. op(str): the op type.
Returns: Returns:
op_path: The path of the file. op_path: The path of the file.
...@@ -252,7 +252,7 @@ def load_predictor(op_type, op_dir, data_type='fp32'): ...@@ -252,7 +252,7 @@ def load_predictor(op_type, op_dir, data_type='fp32'):
elif 'matmul' in op_type: elif 'matmul' in op_type:
op = 'matmul' op = 'matmul'
op_path = dowload_predictor(op_dir, op) op_path = download_predictor(op_dir, op)
with open(op_path, 'rb') as f: with open(op_path, 'rb') as f:
model = pickle.load(f) model = pickle.load(f)
......
...@@ -38,7 +38,9 @@ def get_data_from_tables(table_dict, op_type, data_type='fp32'): ...@@ -38,7 +38,9 @@ def get_data_from_tables(table_dict, op_type, data_type='fp32'):
features = get_features_from_paramkey(param_key, op_type, data_type) features = get_features_from_paramkey(param_key, op_type, data_type)
if features == None: if features == None:
continue continue
# only support bs=1 now
if features[0] != 1:
continue
features.append(table_dict[param_key]) features.append(table_dict[param_key])
data.append(features) data.append(features)
return np.array(data) return np.array(data)
...@@ -47,7 +49,7 @@ def get_data_from_tables(table_dict, op_type, data_type='fp32'): ...@@ -47,7 +49,7 @@ def get_data_from_tables(table_dict, op_type, data_type='fp32'):
def get_features_from_paramkey(param_key, op_type, data_type): def get_features_from_paramkey(param_key, op_type, data_type):
"""Get op's parameters according to the key of latency table """Get op's parameters according to the key of latency table
""" """
features = [] features = None
if 'conv2d' in op_type: if 'conv2d' in op_type:
flag_quant = 'quant=None' if data_type == 'fp32' else 'quant=True' flag_quant = 'quant=None' if data_type == 'fp32' else 'quant=True'
...@@ -58,12 +60,12 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -58,12 +60,12 @@ def get_features_from_paramkey(param_key, op_type, data_type):
param_key).group().split('=')[-1].strip( param_key).group().split('=')[-1].strip(
'(' '('
')').split(', ') ')').split(', ')
outputs = re.search(r'out=(\(-*\d*, \d*, \d*, \d*\))', outputs = re.search(r'out=(\(-*\d*, \d*, -?\d*, -?\d*\))',
param_key).group().split('=')[-1].strip( param_key).group().split('=')[-1].strip(
'(' '('
')').split(', ') ')').split(', ')
batchsize = int(outputs[0])
cout = int(weight[0]) cout = int(outputs[1])
cin = int(weight[1]) cin = int(weight[1])
kernel = int(weight[2]) kernel = int(weight[2])
out_h = int(outputs[2]) out_h = int(outputs[2])
...@@ -75,7 +77,7 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -75,7 +77,7 @@ def get_features_from_paramkey(param_key, op_type, data_type):
out_w) out_w)
if data_type == 'fp32': if data_type == 'fp32':
inputs = re.search(r'in=(\(-*\d*, \d*, \d*, \d*\))', inputs = re.search(r'in=(\(-*\d*, \d*, -?\d*, -?\d*\))',
param_key).group().split('=')[-1].strip( param_key).group().split('=')[-1].strip(
'(' '('
')').split(', ') ')').split(', ')
...@@ -84,20 +86,20 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -84,20 +86,20 @@ def get_features_from_paramkey(param_key, op_type, data_type):
in_w = int(inputs[3]) in_w = int(inputs[3])
features = [ features = [
in_c, cout, kernel, group, stride, pad, in_h * in_w, batchsize, in_c, cout, kernel, group, stride, pad, in_h * in_w,
out_h * out_w out_h * out_w
] ]
else: else:
features = [ features = [
cin, cout, kernel, group, stride, pad, out_h * out_w, flops, batchsize, cin, cout, kernel, group, stride, pad, out_h * out_w,
params flops, params
] ]
elif 'matmul' in op_type: elif 'matmul' in op_type:
X = re.search(r'X=(\(-*\d*, \d*\))', X = re.search(r'X=(\((-?\d+,* *)+\))',
param_key).group().split('=')[-1].strip('(' param_key).group().split('=')[-1].strip('('
')').split(', ') ')').split(', ')
Y = re.search(r'Y=(\(\d*, \d*\))', Y = re.search(r'Y=(\((-?\d+,* *)+\))',
param_key).group().split('=')[-1].strip('(' param_key).group().split('=')[-1].strip('('
')').split(', ') ')').split(', ')
...@@ -106,7 +108,7 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -106,7 +108,7 @@ def get_features_from_paramkey(param_key, op_type, data_type):
c = int(Y[1]) c = int(Y[1])
flops, params = cal_flops_params('fc', b, c) flops, params = cal_flops_params('fc', b, c)
features = [b, c, flops, params] features = [a, b, c, flops, params]
elif ('batch_norm' in op_type or 'layer_norm' in op_type): elif ('batch_norm' in op_type or 'layer_norm' in op_type):
inputs = re.search(r'in=(\((-?\d+,* *)+\))', inputs = re.search(r'in=(\((-?\d+,* *)+\))',
...@@ -114,11 +116,11 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -114,11 +116,11 @@ def get_features_from_paramkey(param_key, op_type, data_type):
'(' '('
')').split(', ') ')').split(', ')
features = [0, 0, 0] features = [0, 0, 0, 0]
for i in range(1, len(inputs)): for i in range(len(inputs)):
if inputs[i] == '': if inputs[i] == '':
continue continue
features[i - 1] = int(inputs[i]) features[i] = int(inputs[i])
elif 'pool2d' in op_type: elif 'pool2d' in op_type:
...@@ -130,7 +132,7 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -130,7 +132,7 @@ def get_features_from_paramkey(param_key, op_type, data_type):
param_key).group().split('=')[-1].strip( param_key).group().split('=')[-1].strip(
'(' '('
')').split(', ') ')').split(', ')
batchsize = int(inputs[0])
cin = int(inputs[1]) cin = int(inputs[1])
in_h = int(inputs[2]) in_h = int(inputs[2])
in_w = int(inputs[3]) in_w = int(inputs[3])
...@@ -147,7 +149,8 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -147,7 +149,8 @@ def get_features_from_paramkey(param_key, op_type, data_type):
flag_type = 1 if 'type=avg' in param_key else 0 flag_type = 1 if 'type=avg' in param_key else 0
features = [ features = [
cin, kernel, stride, pad, in_h * in_w, out_h * out_w, flag_type batchsize, cin, kernel, stride, pad, in_h * in_w, out_h * out_w,
flag_type
] ]
elif ('reshape' in op_type or 'scale' in op_type): elif ('reshape' in op_type or 'scale' in op_type):
...@@ -160,8 +163,8 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -160,8 +163,8 @@ def get_features_from_paramkey(param_key, op_type, data_type):
'(' '('
')').split(',') ')').split(',')
# inputs[4], ouputs[4] # inputs[4], ouputs[4]/[5]
features = [0, 0, 0, 0, 0, 0, 0, 0] features = [0, 0, 0, 0, 0, 0, 0, 0, 0]
for i in range(len(inputs)): for i in range(len(inputs)):
if inputs[i] == '': if inputs[i] == '':
continue continue
...@@ -175,22 +178,28 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -175,22 +178,28 @@ def get_features_from_paramkey(param_key, op_type, data_type):
'leaky_relu' in op_type or 'tanh' in op_type or 'swish' in op_type or 'leaky_relu' in op_type or 'tanh' in op_type or 'swish' in op_type or
'softmax' in op_type or 'hard_sigmoid' in op_type or 'softmax' in op_type or 'hard_sigmoid' in op_type or
'sigmoid' in op_type or 'gelu' in op_type or 'clip' in op_type or 'sigmoid' in op_type or 'gelu' in op_type or 'clip' in op_type or
'shape' in op_type or 'transpose' in op_type or 'shape' in op_type or 'interp_v2' in op_type):
'interp_v2' in op_type):
inputs = re.search(r'in=(\((-?\d+,* *)+\))', inputs = re.search(r'in=(\((-?\d+,* *)+\))',
param_key).group().split('=')[-1].strip( param_key).group().split('=')[-1].strip(
'(' '('
')').split(', ') ')').split(', ')
#cin, h, w
cin = int(inputs[1])
in_h = 0
in_w = 0
if len(inputs) == 4:
in_h = int(inputs[2])
in_w = int(inputs[3])
features = [cin, in_h, in_w] # N, C, H, W
features = [0, 0, 0, 0]
for i in range(len(inputs)):
features[i] = int(inputs[i])
elif 'transpose' in op_type:
inputs = re.search(r'in=(\((-?\d+,* *)+\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
# inputs[4]/[5]
features = [0, 0, 0, 0, 0]
for i in range(len(inputs)):
features[i] = int(inputs[i])
elif 'elementwise' in op_type: elif 'elementwise' in op_type:
X = re.search(r'X=\((-?\d+,* *)+\)', X = re.search(r'X=\((-?\d+,* *)+\)',
...@@ -200,9 +209,9 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -200,9 +209,9 @@ def get_features_from_paramkey(param_key, op_type, data_type):
Y = re.search(r'Y=\((-?\d+,* *)+\)', Y = re.search(r'Y=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip('(' param_key).group().split('=')[-1].strip('('
')').split(',') ')').split(',')
# X[1] X[2] X[3] Y[1] Y[2] Y[3] # X[0] X[1] X[2] X[3] Y[1] Y[2] Y[3]
features = [0, 0, 0, 0, 0, 0] features = [0, 0, 0, 0, 0, 0, 0]
for i in range(1, len(X)): for i in range(len(X)):
if X[i] == '': if X[i] == '':
continue continue
features[i - 1] = int(X[i]) features[i - 1] = int(X[i])
...@@ -211,7 +220,7 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -211,7 +220,7 @@ def get_features_from_paramkey(param_key, op_type, data_type):
continue continue
if Y[i] == '': if Y[i] == '':
continue continue
features[i + 2] = int(Y[i]) features[i + 3] = int(Y[i])
elif 'concat' in op_type: elif 'concat' in op_type:
inputs = re.search(r'in=(\((-?\d+,* *)+\))+', inputs = re.search(r'in=(\((-?\d+,* *)+\))+',
...@@ -222,16 +231,17 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -222,16 +231,17 @@ def get_features_from_paramkey(param_key, op_type, data_type):
channels = [] channels = []
for ins in inputs: for ins in inputs:
channels.append(int(ins.split(', ')[1])) channels.append(int(ins.split(', ')[1]))
#hw, c1,c2,c3,c4,c5,c6,c7,c8,c9 # bs, hw, c1,c2,c3,c4,c5,c6,c7,c8,c9
features = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] features = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
input1 = inputs[0].split(', ') input1 = inputs[0].split(', ')
features[0] = int(input1[0])
if len(input1) == 3: if len(input1) == 3:
features[0] = int(input1[2]) features[1] = int(input1[2])
else: else:
features[0] = int(input1[2]) * int(input1[3]) features[1] = int(input1[2]) * int(input1[3])
for i in range(len(channels)): for i in range(len(channels)):
features[i + 1] = channels[i] features[i + 2] = channels[i]
elif 'yolo_box' in op_type: elif 'yolo_box' in op_type:
outputs = re.search(r'out=(\(-?\d*, \d*, \d*\))', outputs = re.search(r'out=(\(-?\d*, \d*, \d*\))',
...@@ -242,7 +252,7 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -242,7 +252,7 @@ def get_features_from_paramkey(param_key, op_type, data_type):
param_key).group().split('=')[-1].strip( param_key).group().split('=')[-1].strip(
'(' '('
')').split(', ') ')').split(', ')
batchsize = int(inputs[0])
cin = int(inputs[1]) cin = int(inputs[1])
h = int(inputs[2]) h = int(inputs[2])
w = int(inputs[3]) w = int(inputs[3])
...@@ -250,19 +260,19 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -250,19 +260,19 @@ def get_features_from_paramkey(param_key, op_type, data_type):
class_num = int( class_num = int(
re.search(r'class_num=\d*', param_key).group().split('=')[-1]) re.search(r'class_num=\d*', param_key).group().split('=')[-1])
features = [cin, h * w, cout, class_num] features = [batchsize, cin, h * w, cout, class_num]
elif 'prior_box' in op_type: elif 'prior_box' in op_type:
inputs = re.search(r'in=\((-?\d+,* *)+\)', inputs = re.search(r'in=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip( param_key).group().split('=')[-1].strip(
'(' '('
')').split(',') ')').split(',')
batchsize = int(inputs[0])
cin = int(inputs[1]) cin = int(inputs[1])
h = int(inputs[2]) h = int(inputs[2])
w = int(inputs[3]) w = int(inputs[3])
features = [cin, h, w] features = [batchsize, cin, h, w]
elif 'slice' in op_type: elif 'slice' in op_type:
inputs = re.search(r'in=\((-?\d+,* *)+\)', inputs = re.search(r'in=\((-?\d+,* *)+\)',
...@@ -298,40 +308,30 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -298,40 +308,30 @@ def get_features_from_paramkey(param_key, op_type, data_type):
continue continue
features[i] = int(inputs[i]) features[i] = int(inputs[i])
elif 'fc' in op_type:
weight = re.search(r'weight=(\(\d*, \d*, \d*, \d*\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
cin = int(weight[0])
cout = int(weight[1])
flops, params = cal_flops_params('fc', cin, cout)
features = [cin, cout, flops, params]
elif 'shuffle_channel' in op_type: elif 'shuffle_channel' in op_type:
inputs = re.search(r'in=(\(-*\d*, \d*, \d*, \d*\))', inputs = re.search(r'in=(\(-*\d*, \d*, \d*, \d*\))',
param_key).group().split('=')[-1].strip( param_key).group().split('=')[-1].strip(
'(' '('
')').split(', ') ')').split(', ')
batchsize = int(inputs[0])
cin = int(inputs[1]) cin = int(inputs[1])
in_h = int(inputs[2]) in_h = int(inputs[2])
in_w = int(inputs[3]) in_w = int(inputs[3])
group = int(re.search(r'group=\d*', param_key).group().split('=')[1]) group = int(re.search(r'group=\d*', param_key).group().split('=')[1])
features = [cin, in_h, in_w, group] features = [batchsize, cin, in_h, in_w, group]
elif 'split' in op_type: elif 'split' in op_type:
inputs = re.search(r'in=(\(-*\d*, \d*, \d*, \d*\))', inputs = re.search(r'in=(\(-*\d*, \d*, -?\d*, -?\d*\))',
param_key).group().split('=')[-1].strip( param_key).group().split('=')[-1].strip(
'(' '('
')').split(', ') ')').split(', ')
batchsize = int(inputs[0])
cin = int(inputs[1]) cin = int(inputs[1])
in_h = int(inputs[2]) in_h = int(inputs[2])
in_w = int(inputs[2]) in_w = int(inputs[2])
features = [cin, in_h, in_w] features = [batchsize, cin, in_h, in_w]
elif 'squeeze' in op_type: elif 'squeeze' in op_type:
inputs = re.search(r'in=\((-?\d+,* *)+\)', inputs = re.search(r'in=\((-?\d+,* *)+\)',
...@@ -350,7 +350,9 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -350,7 +350,9 @@ def get_features_from_paramkey(param_key, op_type, data_type):
'(' '('
')').split(', ') ')').split(', ')
features = [int(inputs[1]), int(inputs[2]), int(inputs[3])] features = [
int(inputs[0]), int(inputs[1]), int(inputs[2]), int(inputs[3])
]
elif ('calib' in op_type or 'floor' in op_type): elif ('calib' in op_type or 'floor' in op_type):
inputs = re.search(r'in=\((-?\d+,* *)+\)', inputs = re.search(r'in=\((-?\d+,* *)+\)',
...@@ -361,12 +363,12 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -361,12 +363,12 @@ def get_features_from_paramkey(param_key, op_type, data_type):
param_key).group().split('=')[-1].strip( param_key).group().split('=')[-1].strip(
'(' '('
')').split(',') ')').split(',')
# inputs[4] outputs[4]
features = [0, 0, 0, 0, 0, 0] features = [0, 0, 0, 0, 0, 0, 0, 0]
for i in range(1, len(inputs)): for i in range(len(inputs)):
features[i - 1] = int(inputs[i]) features[i] = int(inputs[i])
for i in range(1, len(outputs)): for i in range(len(outputs)):
features[i + 2] = int(outputs[i]) features[i + 4] = int(outputs[i])
elif 'uniform_random' in op_type: elif 'uniform_random' in op_type:
shape = re.search(r'shape=\[(-?\d+,* *)+\]', shape = re.search(r'shape=\[(-?\d+,* *)+\]',
...@@ -379,4 +381,69 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -379,4 +381,69 @@ def get_features_from_paramkey(param_key, op_type, data_type):
continue continue
features[i] = int(shape[i]) features[i] = int(shape[i])
elif 'arg_max' in op_type:
inputs = re.search(r'in=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
outputs = re.search(r'out=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
# inputs[4], outputs[4]
features = [0, 0, 0, 0, 0, 0, 0, 0]
for i in range(len(inputs)):
if inputs[i] == '':
continue
features[i] = int(inputs[i])
for i in range(len(outputs)):
if outputs[i] == '':
continue
features[i + 4] = int(outputs[i])
elif 'fill_constant_batch_size_like' in op_type:
inputs = re.search(r'in=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
outputs = re.search(r'out=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
# inputs[4], outputs[4]
features = [0, 0, 0, 0, 0, 0, 0, 0]
for i in range(len(inputs)):
if inputs[i] == '':
continue
features[i] = int(inputs[i])
for i in range(len(outputs)):
if outputs[i] == '':
continue
features[i + 4] = int(outputs[i])
elif op_type == 'rnn':
inputs = re.search(r'in=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
inputs[0], inputs[1] = inputs[1], inputs[0]
outputs = re.search(r'out=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
outputs[0], outputs[1] = outputs[1], outputs[0]
# inputs[3], outputs[3]
features = [0, 0, 0, 0, 0, 0]
for i in range(len(inputs)):
if inputs[i] == '':
continue
features[i] = int(inputs[i])
for i in range(len(outputs)):
if outputs[i] == '':
continue
features[i + 3] = int(outputs[i])
return features return features
...@@ -23,9 +23,18 @@ from .extract_features import get_data_from_tables, get_features_from_paramkey ...@@ -23,9 +23,18 @@ from .extract_features import get_data_from_tables, get_features_from_paramkey
from ._utils import opt_model, load_predictor, nearest_interpolate from ._utils import opt_model, load_predictor, nearest_interpolate
import paddle import paddle
import paddleslim import paddleslim
import warnings
__all__ = ["LatencyPredictor", "TableLatencyPredictor"] __all__ = ["LatencyPredictor", "TableLatencyPredictor"]
def format_Warning(message, category, filename, lineno, line=''):
return str(filename) + ':' + str(
lineno) + ': ' + category.__name__ + ': ' + str(message) + '\n'
warnings.formatwarning = format_Warning
class LatencyPredictor(object): class LatencyPredictor(object):
"""Base class of latency predictor. """Base class of latency predictor.
""" """
...@@ -53,7 +62,7 @@ class TableLatencyPredictor(LatencyPredictor): ...@@ -53,7 +62,7 @@ class TableLatencyPredictor(LatencyPredictor):
"""The preditor used to get pbmodel's latency on some devices and infer engines. """The preditor used to get pbmodel's latency on some devices and infer engines.
Args: Args:
table_file(str): The path of file that records the devices latency of operators. table_file(str): The path of file that records the device latency of operators.
""" """
def __init__(self, table_file='SD710'): def __init__(self, table_file='SD710'):
...@@ -78,7 +87,7 @@ class TableLatencyPredictor(LatencyPredictor): ...@@ -78,7 +87,7 @@ class TableLatencyPredictor(LatencyPredictor):
assert os.path.exists( assert os.path.exists(
self.table_file self.table_file
), f'{self.table_file} is not existed. If you want to use our table files, please set \'table_file\' in [SD625, SD710, SD845, SD865]' ), f'{self.table_file} does not exist. If you want to use our table files, please set \'table_file\' in [SD625, SD710, SD845, SD865]'
with open(self.table_file, 'rb') as f: with open(self.table_file, 'rb') as f:
self.table_dict = pickle.load(f) self.table_dict = pickle.load(f)
...@@ -95,7 +104,7 @@ class TableLatencyPredictor(LatencyPredictor): ...@@ -95,7 +104,7 @@ class TableLatencyPredictor(LatencyPredictor):
with open(self.table_file, 'rb') as f: with open(self.table_file, 'rb') as f:
self.table_dict = pickle.load(f) self.table_dict = pickle.load(f)
print('Successfully load {}'.format(self.table_file)) print('Successfully loaded {}'.format(self.table_file))
def _get_input_shape(self, graph): def _get_input_shape(self, graph):
in_shape = [] in_shape = []
...@@ -118,7 +127,7 @@ class TableLatencyPredictor(LatencyPredictor): ...@@ -118,7 +127,7 @@ class TableLatencyPredictor(LatencyPredictor):
model_file(str), param_file(str): The inference model(*.pdmodel, *.pdiparams). model_file(str), param_file(str): The inference model(*.pdmodel, *.pdiparams).
data_type(str): Data type, fp32 or int8. Default : fp32 data_type(str): Data type, fp32 or int8. Default : fp32
threads(int): threads num threads(int): threads num
input_shape(list): Generally, the input shape is confirmed when saving the inference model and the parameter is only effective for variable length input shape. input_shape(list): Generally, the input shape is confirmed when saving the inference model and the parameter is only effective for input shape that has variable length.
Returns: Returns:
latency(float): The latency of the model. latency(float): The latency of the model.
""" """
...@@ -142,19 +151,31 @@ class TableLatencyPredictor(LatencyPredictor): ...@@ -142,19 +151,31 @@ class TableLatencyPredictor(LatencyPredictor):
if input_shape != None: if input_shape != None:
ori_shape = self._get_input_shape(graph) ori_shape = self._get_input_shape(graph)
assert ori_shape == input_shape, "The parameter \'input_shape\' dosn't work now. The input shape is confirmed when saving the inference model" assert ori_shape == input_shape, "The parameter \'input_shape\' dosn't work for now. The input shape is fixed when saving the inference model"
latency = 0.0 latency = 0.0
new_op = {}
for op in graph.ops(): for op in graph.ops():
param_key = get_key_from_op(op) param_key = get_key_from_op(op)
if param_key == '': if param_key == '':
continue continue
if param_key == None:
if op.type() in new_op:
new_op[op.type()] += 1
else:
new_op.update({op.type(): 1})
continue
if param_key in self.table_dict: if param_key in self.table_dict:
latency += self.table_dict[param_key] latency += self.table_dict[param_key]
elif self.predictor_state: elif self.predictor_state:
latency += self.op_predictor(op.type(), param_key, data_type) latency += self.op_predictor(op.type(), param_key, data_type)
else: if len(new_op) != 0:
raise AssertionError(f'{param_key} is not in the table.') warnings.warn(
"These ops are not currently supported. Please raise an issue in PaddleSlim if you find the CalledTimes is large enough to affect the accuracy."
)
warnings.warn("OperatorType\tCalledTimes")
for key in new_op:
warnings.warn(f"{key.ljust(15)}\t{new_op[key]}")
return latency return latency
......
...@@ -234,12 +234,28 @@ def get_key_from_op(op): ...@@ -234,12 +234,28 @@ def get_key_from_op(op):
param_key = f'{op_type} in={in_shape} out={out_shape} paddings={paddings}' param_key = f'{op_type} in={in_shape} out={out_shape} paddings={paddings}'
elif op_type == 'arg_max':
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
axis = op.attr('axis')
param_key = f'{op_type} in={in_shape} axis={axis} out={out_shape}'
elif op_type == 'fill_constant_batch_size_like':
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
shape = op.attr('shape')
param_key = f'{op_type} in={in_shape} shape={shape} out={out_shape}'
elif op_type == 'rnn':
out_shape = op.all_outputs()[1].shape()
in_shape = op.all_inputs()[0].shape()
param_key = f'{op_type} in={in_shape} out={out_shape}'
elif op_type in ['feed', 'fetch']: elif op_type in ['feed', 'fetch']:
pass pass
else: else:
print(op) param_key = None
print(op._op)
raise KeyError(f'The "{op_type}" has never seen.')
return param_key return param_key
...@@ -140,6 +140,8 @@ class ModelCase6(paddle.nn.Layer): ...@@ -140,6 +140,8 @@ class ModelCase6(paddle.nn.Layer):
self.relu1 = ReLU() self.relu1 = ReLU()
self.fc1 = paddle.nn.Linear(3 * 16 * 16, 3 * 16 * 16) self.fc1 = paddle.nn.Linear(3 * 16 * 16, 3 * 16 * 16)
self.dp = paddle.nn.Dropout(p=0.5) self.dp = paddle.nn.Dropout(p=0.5)
self.lstm = paddle.nn.LSTM(
1536, 10, direction='bidirectional', num_layers=2)
def forward(self, inputs): def forward(self, inputs):
x = self.bn1(inputs) x = self.bn1(inputs)
...@@ -149,17 +151,24 @@ class ModelCase6(paddle.nn.Layer): ...@@ -149,17 +151,24 @@ class ModelCase6(paddle.nn.Layer):
x = self.relu1(x) x = self.relu1(x)
y = paddle.fluid.layers.fill_constant( y = paddle.fluid.layers.fill_constant(
x.shape, dtype=paddle.float32, value=1) x.shape, dtype=paddle.float32, value=1)
x = paddle.stack([x, y], axis=3) # x = paddle.stack([x, y], axis=3)
x = paddle.slice(x, axes=[0], starts=[0], ends=[1]) x = paddle.slice(x, axes=[0], starts=[0], ends=[1])
x = paddle.exp(x) x = paddle.exp(x)
y += paddle.fluid.layers.uniform_random(y.shape) # y += paddle.fluid.layers.uniform_random(y.shape)
y = paddle.expand(y, shape=[1, 768, 768, 2]) y = paddle.expand(y, shape=[1, 768, 768, 2])
x = paddle.expand(x, shape=[1, 768, 768, 2]) x = paddle.expand(x, shape=[1, 768, 768, 2])
out = paddle.concat([x, y]) out = paddle.concat([x, y])
out = self.dp(out) out = self.dp(out)
out = channel_shuffle(out, 2) out = channel_shuffle(out, 2)
out1, out2 = paddle.split(out, num_or_sections=2, axis=1) out1, out2 = paddle.split(out, num_or_sections=2, axis=1)
return out1, out2 outshape = out1.shape
max_idx = paddle.argmax(
out1.reshape((outshape[0], outshape[1], outshape[2] * outshape[3])),
axis=-1)
out2 = out2.reshape(
(outshape[0], outshape[1], outshape[2] * outshape[3]))
res, _ = self.lstm(out2)
return res, max_idx
class ModelCase7(paddle.nn.Layer): class ModelCase7(paddle.nn.Layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册