提交 9e4b1355 编写于 作者: W wanghaoshuang

Add api and doc for table latency.

上级 0a582fb9
......@@ -139,3 +139,31 @@ with fluid.program_guard(main_program, startup_program):
print("FLOPS: {}".format(model_size(main_program)))
```
## TableLatencyEvaluator
>paddleslim.analysis.TableLatencyEvaluator(table_file, delimiter=",") [源代码]()
基于硬件延时表的模型延时评估器。
**参数:**
- **table_file(str):** 所使用的延时评估表的绝对路径。关于演示评估表格式请参考:[PaddleSlim硬件延时评估表格式](../paddleslim/analysis/table_latency.md)
- **delimiter(str):** 硬件延时评估表中,操作信息之前所使用的分割符,默认为英文字符逗号。
**返回值:**
- **Evaluator:** 硬件延时评估器的实例。
>paddleslim.analysis.TableLatencyEvaluator.latency(graph) [源代码]()
获得指定网络的预估延时。
**参数:**
- **graph(Program):** 待预估的目标网络。
**返回值:**
- **latency:** 目标网络的预估延时。
......@@ -15,6 +15,9 @@ import flops as flops_module
from flops import *
import model_size as model_size_module
from model_size import *
import lantency as lantency_module
from lantency import *
__all__ = []
__all__ += flops_module.__all__
__all__ += model_size_module.__all__
__all__ += lantency_module.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["LatencyEvaluator", "TableLatencyEvaluator"]
class LatencyEvaluator(object):
def __init__(self):
pass
def latency(self, graph):
pass
def _get_ops_from_graph(self, graph):
assert isinstance(graph, GraphWrapper)
ops = []
i = 0
for op in graph.ops():
if op.type() in ['conv2d', 'depthwise_conv2d']:
tmp = _conv_op_args(op)
elif op.type() in [
'elementwise_add', 'elementwise_mul', 'elementwise_max'
]:
tmp = _eltwise_op_args(op)
elif op.type() in [
'relu', 'prelu', 'sigmoid', 'relu6', 'elu', 'brelu',
'leaky_relu'
]:
tmp = _activation_op_args(op)
elif op.type() == 'batch_norm':
tmp = _batch_norm_op_args(op)
elif op.type() == 'pool2d':
tmp = _pooling_op_args(op)
elif op.type() == 'batch_norm':
tmp = _batch_norm_op_args(op)
elif op.type() == 'softmax':
tmp = _softmax_op_args(op)
elif op.type() == 'mul':
tmp = _fc_op_args(op)
else:
tmp = None
if tmp:
ops.append(tmp)
return ops
def _conv_op_args(op):
assert isinstance(op, OpWrapper)
tmp, res = [], []
# op_name
tmp.append('conv')
# flag_bias
if len(op.inputs('Bias')) == 0:
tmp.append(0)
else:
tmp.append(1)
# flag_relu
tmp.append(int(op.attr('fuse_relu')))
# batch size
tmp.append(1)
# channels, height, width
in_shapes = op.inputs('Input')[0].shape
tmp = tmp + [int(in_shapes[1]), int(in_shapes[2]), int(in_shapes[3])]
# output channels
w_shapes = op.inputs('Filter')[0].shape
tmp.append(int(w_shapes[0]))
# group
tmp.append(int(op.attr('groups')))
# kernel size
tmp.append(int(w_shapes[2]))
if w_shapes[2] != w_shapes[3]:
res.append(int(w_shapes[3]))
# padding
paddings = op.attr('paddings')
tmp.append(int(paddings[0]))
if paddings[0] != paddings[1]:
res.append(int(paddings[0]))
# strides
strides = op.attr('strides')
tmp.append(int(strides[0]))
if strides[0] != strides[1]:
res.append(int(strides[1]))
# dilations
dilations = op.attr('dilations')
tmp.append(int(dilations[0]))
if dilations[0] != dilations[1]:
res.append(int(dilations[1]))
tmp = tmp + res
return tmp
def _batch_norm_op_args(op):
tmp = []
# op name
tmp.append('batch_norm')
# activation type
if not op.attr('fuse_with_relu'):
tmp.append('None')
else:
tmp.append('relu')
# batch size
tmp.append(1)
# input channels, height, width
in_shapes = op.inputs("X")[0].shape
tmp = tmp + [int(in_shapes[1]), int(in_shapes[2]), int(in_shapes[3])]
return tmp
def _eltwise_op_args(op):
# op name
tmp = ['eltwise']
# elementwise type, TODO: add more ops
if op.type() == 'elementwise_mul':
tmp.append(1)
elif op.type() == 'elementwise_add':
tmp.append(2)
else:
tmp.append(3)
# batch size
tmp.append(1)
# input channels, height, width
in_shapes = op.inputs('X')[0].shape
while len(in_shapes) < 4:
in_shapes = in_shapes + (1, )
for i in range(1, len(in_shapes)):
tmp.append(int(in_shapes[i]))
return tmp
def _activation_op_args(op):
tmp = []
# activation type
tmp.append(op.type())
# batch size
tmp.append(1)
# input channels, height, width
in_shapes = op.inputs('X')[0].shape
while len(in_shapes) < 4:
in_shapes = in_shapes + (1, )
for i in range(1, len(in_shapes)):
tmp.append(int(in_shapes[i]))
return tmp
def _pooling_op_args(op):
tmp, res = [], []
# op name
tmp.append('pooling')
# global pooling
tmp.append(int(op.attr('global_pooling')))
# batch size
tmp.append(1)
# channels, height, width
in_shapes = op.inputs('X')[0].shape
tmp = tmp + [int(in_shapes[1]), int(in_shapes[2]), int(in_shapes[3])]
# kernel size
ksize = op.attr('ksize')
tmp.append(int(ksize[0]))
if ksize[0] != ksize[1]:
res.append(int(ksize[1]))
# padding
paddings = op.attr('paddings')
tmp.append(int(paddings[0]))
if paddings[0] != paddings[1]:
res.append(int(paddings[1]))
# stride
strides = op.attr('strides')
tmp.append(int(strides[0]))
if strides[0] != strides[1]:
res.append(int(strides[1]))
# ceil mode
tmp.append(int(op.attr('ceil_mode')))
# pool type
pool_type = op.attr('pooling_type')
exclusive = op.attr('exclusive')
if pool_type == 'max' and (not exclusive):
tmp.append(1)
elif pool_type == 'avg' and (not exclusive):
tmp.append(2)
else:
tmp.append(3)
tmp = tmp + res
return tmp
def _softmax_op_args(op):
# op name
tmp = ['softmax']
# axis
tmp.append(op.attr('axis'))
# batch size
tmp.append(1)
# input channels, height, width
in_shapes = op.inputs('X')[0].shape
while len(in_shapes) < 4:
in_shapes = in_shapes + (1, )
for i in range(1, len(in_shapes)):
tmp.append(int(in_shapes[i]))
return tmp
def _fc_op_args(blocks, op):
# op name
tmp = ['conv']
# flag bias
tmp.append(0)
# flag relu
tmp.append(0)
# batch size
tmp.append(1)
# input channels, height, width
channels = 1
in_shape = op.inputs('X')[0].shape
for i in range(1, len(in_shape)):
channels *= in_shape[i]
tmp = tmp + [int(channels), 1, 1]
# output channels
tmp.append(int(op.outputs('Out')[0].shape[1]))
# groups, kernel size, padding, stride, dilation
tmp = tmp + [1, 1, 0, 1, 1]
return tmp
class TableLatencyEvaluator(LatencyEvaluator):
def __init__(self, table_file, delimiter=","):
"""
The evaluator used to get graph's latency on some devices and infer engines.
Args:
- table_file(str): The path of file that records the devices latency of operators.
- delimiter(str): The delimiter used in `table_file`.
"""
self._table = self._load_table(table_file)
self._delimiter = delimiter
def _load_table(self, table_file):
table = {}
with open(table_file) as f:
line = f.readline()
self.infer_engine_name, self.device_name, self.create_time = line.strip(
).split("\t")
for line in f:
op_str, latency = line.strip().split("\t")
table[op_str] = float(latency)
return table
def _op_latency(self, op_str):
assert op_str in self._table
return self._table[op_str]
def latency(self, graph):
"""
Get latency of target graph.
Args:
- graph(GrapWrapper | Program): The graph to be evaluated.
Returns:
latency(float): The latency of given graph on current evaluator.
"""
total_latency = 0
if isinstance(graph, Program):
graph = GraphWrapper(graph)
assert isinstance(graph, GraphWrapper)
for op in self._get_ops_from_graph(graph):
total_latency += self._op_latency(self._delimiter.join(op))
return total_latency
# 硬件延时评估表
硬件延时评估表用于快速评估一个模型在特定硬件环境和推理引擎上的推理速度。
该文档主要用于定义PaddleSlim支持的硬件延时评估表的格式。
## 概述
硬件延时评估表中存放着所有可能的操作对应的延时信息,该表中的一个操作包括操作类型和操作参数,比如:操作类型可以是`conv2d`,对应的操作参数有输入特征图的大小、卷积核个数、卷积核大小等。
给定操作的延时依赖于硬件环境和推理引擎。
## 整体格式
硬件延时评估表以文件或多行字符串的形式保存。
硬件延时评估表第一行保存版本信息,后续每行为一个操作和对应的延时信息。
## 版本信息
版本信息以英文字符逗号分割,内容依次为硬件环境名称、推理引擎名称和时间戳。
- **硬件环境名称:** 用于标识硬件环境,可以包含计算架构类型、版本号等信息。
- **推理引擎名称:** 用于标识推理引擎,可以包含推理引擎名称、版本号、优化选项等信息。
- **时间戳:** 该评估表的创建时间。
## 操作信息
操作信息字段之间以逗号分割。操作信息与延迟信息之间以制表符分割。
### conv2d
**格式**
```
op_type,flag_bias,flag_relu,n_in,c_in,h_in,w_in,c_out,groups,kernel,padding,stride,dilation\tlatency
```
**字段解释**
- **op_type(str)** - 当前op类型。
- **flag_bias (int)** - 是否有 bias(0:无,1:有)。
- **flag_relu (int)** - 是否有 relu(0:无,1:有)。
- **n_in (int)** - 输入 Tensor 的批尺寸 (batch size)。
- **c_in (int)** - 输入 Tensor 的通道 (channel) 数。
- **h_in (int)** - 输入 Tensor 的特征高度。
- **w_in (int)** - 输入 Tensor 的特征宽度。
- **c_out (int)** - 输出 Tensor 的通道 (channel) 数。
- **groups (int)** - 卷积二维层(Conv2D Layer)的组数。
- **kernel (int)** - 卷积核大小。
- **padding (int)** - 填充 (padding) 大小。
- **stride (int)** - 步长 (stride) 大小。
- **dilation (int)** - 膨胀 (dilation) 大小。
- **latency (float)** - 当前op的延时时间
### activation
**格式**
```
op_type,n_in,c_in,h_in,w_in\tlatency
```
**字段解释**
- **op_type(str)** - 当前op类型。
- **n_in (int)** - 输入 Tensor 的批尺寸 (batch size)。
- **c_in (int)** - 输入 Tensor 的通道 (channel) 数。
- **h_in (int)** - 输入 Tensor 的特征高度。
- **w_in (int)** - 输入 Tensor 的特征宽度。
- **latency (float)** - 当前op的延时时间
### batch_norm
**格式**
```
op_type,active_type,n_in,c_in,h_in,w_in\tlatency
```
**字段解释**
- **op_type(str)** - 当前op类型。
- **active_type (string)** - 激活函数类型,包含:relu, prelu, sigmoid, relu6, tanh。
- **n_in (int)** - 输入 Tensor 的批尺寸 (batch size)。
- **c_in (int)** - 输入 Tensor 的通道 (channel) 数。
- **h_in (int)** - 输入 Tensor 的特征高度。
- **w_in (int)** - 输入 Tensor 的特征宽度。
- **latency (float)** - 当前op的延时时间
### eltwise
**格式**
```
op_type,n_in,c_in,h_in,w_in\tlatency
```
**字段解释**
- **op_type(str)** - 当前op类型。
- **n_in (int)** - 输入 Tensor 的批尺寸 (batch size)。
- **c_in (int)** - 输入 Tensor 的通道 (channel) 数。
- **h_in (int)** - 输入 Tensor 的特征高度。
- **w_in (int)** - 输入 Tensor 的特征宽度。
- **latency (float)** - 当前op的延时时间
### pooling
**格式**
```
op_type,flag_global_pooling,n_in,c_in,h_in,w_in,kernel,padding,stride,ceil_mode,pool_type\tlatency
```
**字段解释**
- **op_type(str)** - 当前op类型。
- **flag_global_pooling (int)** - 是否为全局池化(0:不是,1:是)。
- **n_in (int)** - 输入 Tensor 的批尺寸 (batch size)。
- **c_in (int)** - 输入 Tensor 的通道 (channel) 数。
- **h_in (int)** - 输入 Tensor 的特征高度。
- **w_in (int)** - 输入 Tensor 的特征宽度。
- **kernel (int)** - 卷积核大小。
- **padding (int)** - 填充 (padding) 大小。
- **stride (int)** - 步长 (stride) 大小。
- **ceil_mode (int)** - 是否用 ceil 函数计算输出高度和宽度。0 表示使用 floor 函数,1 表示使用 ceil 函数。
- **pool_type (int)** - 池化类型,其中 1 表示 pooling_max,2 表示 pooling_average_include_padding,3 表示 pooling_average_exclude_padding。
- **latency (float)** - 当前op的延时时间
### softmax
**格式**
```
op_type,axis,n_in,c_in,h_in,w_in\tlatency
```
**字段解释**
- **op_type(str)** - 当前op类型。
- **axis (int)** - 执行 softmax 计算的维度索引,应该在 [−1,rank − 1] 范围内,其中 rank 是输入变量的秩。
- **n_in (int)** - 输入 Tensor 的批尺寸 (batch size)。
- **c_in (int)** - 输入 Tensor 的通道 (channel) 数。
- **h_in (int)** - 输入 Tensor 的特征高度。
- **w_in (int)** - 输入 Tensor 的特征宽度。
- **latency (float)** - 当前op的延时时间
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册