未验证 提交 d98c7738 编写于 作者: Z zhouzj 提交者: GitHub

[LatencyPredictor] add hardware (#1089)

* Add rk3288 predictor

* fix some bugs for sparse conv2d.
上级 d01db560
......@@ -61,7 +61,7 @@ def get_features_from_paramkey(param_key, op_type, data_type):
if quant_bits not in param_key:
return None
weight = re.search(r'weight=(\(\d*, \d*, \d*, \d*\))',
weight = re.search(r'weight=(\(\d*, -?\d*, \d*, \d*\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
......
......@@ -68,7 +68,7 @@ class TableLatencyPredictor(LatencyPredictor):
Args:
table_file(str): The path of file that records the device latency of operators.
"""
hardware_list = ['SD625', 'SD710']
hardware_list = ['SD625', 'SD710', 'RK3288']
def __init__(self, table_file='SD710'):
self.table_file = table_file
......
......@@ -24,7 +24,10 @@ def get_key_from_op(op):
if op_type == 'sparse_conv2d':
out_shape = op.all_outputs()[0].shape()
in_shape = op.inputs('Input')[0].shape()
weight_shape = (out_shape[1], in_shape[1], 1, 1)
if in_shape:
weight_shape = (out_shape[1], in_shape[1], 1, 1)
else:
weight_shape = (out_shape[1], -1, 1, 1)
NonZeroWeights = op.inputs('NonZeroWeights')[0].shape()[0]
stride = op.attr('strides')[1]
......@@ -147,14 +150,13 @@ def get_key_from_op(op):
elif op_type == 'stack':
data = op.all_inputs()
X = "["
X = ""
for x in data:
X += f"{x.shape()}"
X += "]"
axis = op.attr('axis')
out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} X={X} axis={axis} out={out_shape}'
param_key = f'{op_type} in={X} axis={axis} out={out_shape}'
elif op_type == 'exp':
in_shape = op.all_inputs()[-1].shape()
......@@ -219,7 +221,7 @@ def get_key_from_op(op):
in_shape2 = op.all_inputs()[1].shape()
out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} in={in_shape1} in={in_shape2} out={out_shape}'
param_key = f'{op_type} X={in_shape1} Y={in_shape2} out={out_shape}'
elif op_type in ['calib', 'floor']:
in_shape = op.all_inputs()[-1].shape()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册