From 76f0d7e2a89f9a2b0dc671f88a833eea17eccf57 Mon Sep 17 00:00:00 2001
From: ceci3
Date: Thu, 16 Jan 2020 10:54:08 +0800
Subject: [PATCH] Conv only for latency (#37)

---
 paddleslim/analysis/flops.py   |  1 -
 paddleslim/analysis/latency.py | 21 ++++++++++-----------
 2 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/paddleslim/analysis/flops.py b/paddleslim/analysis/flops.py
index 4e710fdc..1f01f7c9 100644
--- a/paddleslim/analysis/flops.py
+++ b/paddleslim/analysis/flops.py
@@ -66,7 +66,6 @@ def _graph_flops(graph, only_conv=True, detail=False):
             y_shape = op.inputs("Y")[0].shape()
             if x_shape[0] == -1:
                 x_shape[0] = 1
-            flops += x_shape[0] * x_shape[1] * y_shape[1]
             op_flops = x_shape[0] * x_shape[1] * y_shape[1]
             flops += op_flops
diff --git a/paddleslim/analysis/latency.py b/paddleslim/analysis/latency.py
index fea255c7..1c2fe87a 100644
--- a/paddleslim/analysis/latency.py
+++ b/paddleslim/analysis/latency.py
@@ -24,7 +24,7 @@ class LatencyEvaluator(object):
     def latency(self, graph):
         pass

-    def _get_ops_from_graph(self, graph):
+    def _get_ops_from_graph(self, graph, only_conv):
         assert isinstance(graph, GraphWrapper)
         ops = []
         i = 0
@@ -33,22 +33,20 @@ class LatencyEvaluator(object):
                 tmp = self._conv_op_args(op)
             elif op.type() in [
                     'elementwise_add', 'elementwise_mul', 'elementwise_max'
-            ]:
+            ] and only_conv == False:
                 tmp = self._eltwise_op_args(op)
             elif op.type() in [
                     'relu', 'prelu', 'sigmoid', 'relu6', 'elu', 'brelu',
                     'leaky_relu'
-            ]:
+            ] and only_conv == False:
                 tmp = self._activation_op_args(op)
-            elif op.type() == 'batch_norm':
+            elif op.type() == 'batch_norm' and only_conv == False:
                 tmp = self._batch_norm_op_args(op)
-            elif op.type() == 'pool2d':
+            elif op.type() == 'pool2d' and only_conv == False:
                 tmp = self._pooling_op_args(op)
-            elif op.type() == 'batch_norm':
-                tmp = self._batch_norm_op_args(op)
-            elif op.type() == 'softmax':
+            elif op.type() == 'softmax' and only_conv == False:
                 tmp = self._softmax_op_args(op)
-            elif op.type() == 'mul':
+            elif op.type() == 'mul' and only_conv == False:
                 tmp = self._fc_op_args(op)
             else:
                 tmp = None
@@ -268,11 +266,12 @@ class TableLatencyEvaluator(LatencyEvaluator):
         assert op_str in self._table
         return self._table[op_str]

-    def latency(self, graph):
+    def latency(self, graph, only_conv=True):
         """
         Get latency of target graph.
         Args:
         - graph(GrapWrapper | Program): The graph to be evaluated.
+        - only_conv(bool): Whether to evaluate only convolution layers. Default: True.
         Returns:
         latency(float): The latency of given graph on current evaluator.
         """
@@ -280,7 +279,7 @@
         if isinstance(graph, Program):
             graph = GraphWrapper(graph)
         assert isinstance(graph, GraphWrapper)
-        for op in self._get_ops_from_graph(graph):
+        for op in self._get_ops_from_graph(graph, only_conv):
            total_latency += self._op_latency(
                self._delimiter.join(map(lambda x: str(x), op)))
        return total_latency
--
GitLab
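
For reference, a minimal usage sketch of the API this patch changes, assuming the PaddlePaddle 1.x-era fluid API and that TableLatencyEvaluator is constructed from a latency table file (the table path and program variable below are hypothetical, not part of this patch):

    import paddle.fluid as fluid
    from paddleslim.analysis.latency import TableLatencyEvaluator

    # The program to be evaluated; latency() accepts a Program or a
    # GraphWrapper, per its docstring.
    main_program = fluid.default_main_program()

    # The table file maps serialized op descriptions to latencies
    # measured on the target device (hypothetical path).
    evaluator = TableLatencyEvaluator(table_file="./latency_table.txt")

    # With this patch, only_conv defaults to True, so only convolution
    # ops contribute to the reported latency.
    conv_latency = evaluator.latency(main_program)

    # Passing only_conv=False restores the previous behavior: eltwise,
    # activation, batch_norm, pool2d, softmax and mul ops are counted too.
    full_latency = evaluator.latency(main_program, only_conv=False)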