Commit 76f0d7e2 authored by ceci3, committed by whs

Conv only for latency (#37)

Parent 50d69ece
...
@@ -66,7 +66,6 @@ def _graph_flops(graph, only_conv=True, detail=False):
             y_shape = op.inputs("Y")[0].shape()
             if x_shape[0] == -1:
                 x_shape[0] = 1
-            flops += x_shape[0] * x_shape[1] * y_shape[1]
             op_flops = x_shape[0] * x_shape[1] * y_shape[1]
             flops += op_flops
...
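The removed line above had double-counted FLOPs for 'mul' (fully-connected) ops: the product x_shape[0] * x_shape[1] * y_shape[1] was added to flops both directly and again through op_flops. A minimal sketch of the corrected accounting; the shapes are hypothetical stand-ins for op.inputs("X") and op.inputs("Y"):

    # Minimal sketch of the corrected FLOPs accounting for a 'mul' op.
    # The shapes below are hypothetical stand-ins for op.inputs(...).
    flops = 0
    x_shape = [-1, 1024]   # batch dimension may be -1 (unknown at compile time)
    y_shape = [1024, 10]
    if x_shape[0] == -1:
        x_shape[0] = 1     # normalize the unknown batch size to 1
    op_flops = x_shape[0] * x_shape[1] * y_shape[1]  # counted exactly once
    flops += op_flops
    print(flops)  # 10240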
...
@@ -24,7 +24,7 @@ class LatencyEvaluator(object):
     def latency(self, graph):
         pass

-    def _get_ops_from_graph(self, graph):
+    def _get_ops_from_graph(self, graph, only_conv):
         assert isinstance(graph, GraphWrapper)
         ops = []
         i = 0
...
@@ -33,22 +33,20 @@ class LatencyEvaluator(object):
                 tmp = self._conv_op_args(op)
             elif op.type() in [
                     'elementwise_add', 'elementwise_mul', 'elementwise_max'
-            ]:
+            ] and only_conv == False:
                 tmp = self._eltwise_op_args(op)
             elif op.type() in [
                     'relu', 'prelu', 'sigmoid', 'relu6', 'elu', 'brelu',
                     'leaky_relu'
-            ]:
+            ] and only_conv == False:
                 tmp = self._activation_op_args(op)
-            elif op.type() == 'batch_norm':
+            elif op.type() == 'batch_norm' and only_conv == False:
                 tmp = self._batch_norm_op_args(op)
-            elif op.type() == 'pool2d':
+            elif op.type() == 'pool2d' and only_conv == False:
                 tmp = self._pooling_op_args(op)
-            elif op.type() == 'batch_norm':
-                tmp = self._batch_norm_op_args(op)
-            elif op.type() == 'softmax':
+            elif op.type() == 'softmax' and only_conv == False:
                 tmp = self._softmax_op_args(op)
-            elif op.type() == 'mul':
+            elif op.type() == 'mul' and only_conv == False:
                 tmp = self._fc_op_args(op)
             else:
                 tmp = None
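Besides removing a duplicated batch_norm branch, this hunk threads only_conv through _get_ops_from_graph so that every non-convolution branch is gated on the flag. A condensed sketch of the resulting dispatch; the conv op type names ('conv2d', 'depthwise_conv2d') are assumptions, since the conv branch condition lies outside this hunk:

    def _collect_op_kind(op_type, only_conv):
        # Condensed sketch of the dispatch above: conv ops are always
        # collected; all other supported op types are skipped when
        # only_conv is True (the `and only_conv == False` guards).
        if op_type in ('conv2d', 'depthwise_conv2d'):  # assumed conv types
            return 'conv'
        if only_conv:
            return None
        supported = {
            'elementwise_add': 'eltwise', 'elementwise_mul': 'eltwise',
            'elementwise_max': 'eltwise', 'relu': 'activation',
            'batch_norm': 'batch_norm', 'pool2d': 'pooling',
            'softmax': 'softmax', 'mul': 'fc',
        }
        return supported.get(op_type)  # None for unsupported op types

    assert _collect_op_kind('relu', only_conv=True) is None
    assert _collect_op_kind('conv2d', only_conv=True) == 'conv'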
...
@@ -268,11 +266,12 @@ class TableLatencyEvaluator(LatencyEvaluator):
         assert op_str in self._table
         return self._table[op_str]

-    def latency(self, graph):
+    def latency(self, graph, only_conv=True):
         """
         Get the latency of the target graph.
         Args:
             - graph(GraphWrapper | Program): The graph to be evaluated.
+            - only_conv(bool): Evaluate only convolution layers if True. Default: True.
         Returns:
             latency(float): The latency of the given graph on the current evaluator.
         """
...
@@ -280,7 +279,7 @@ class TableLatencyEvaluator(LatencyEvaluator):
         if isinstance(graph, Program):
             graph = GraphWrapper(graph)
         assert isinstance(graph, GraphWrapper)
-        for op in self._get_ops_from_graph(graph):
+        for op in self._get_ops_from_graph(graph, only_conv):
             total_latency += self._op_latency(
                 self._delimiter.join(map(lambda x: str(x), op)))
         return total_latency
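Taken together, the change lets callers restrict a table-based latency estimate to convolution layers. A hedged usage sketch; the constructor arguments (table_file, delimiter) and the way the Program is built are assumptions based on context, not shown in this diff:

    import paddle.fluid as fluid

    # Hypothetical setup: a latency table file and a Program holding the network.
    evaluator = TableLatencyEvaluator(table_file='latency_table.txt', delimiter=',')
    main_program = fluid.Program()  # assumed to contain the target network

    conv_latency = evaluator.latency(main_program, only_conv=True)    # conv ops only
    total_latency = evaluator.latency(main_program, only_conv=False)  # all supported ops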