未验证 提交 3ea12865 编写于 作者: W whs 提交者: GitHub

Fix FLOPs API (#11)

上级 e34a1505
...@@ -18,14 +18,22 @@ from ..core import GraphWrapper ...@@ -18,14 +18,22 @@ from ..core import GraphWrapper
__all__ = ["flops"] __all__ = ["flops"]
def flops(program, detail=False): def flops(program, only_conv=True, detail=False):
""" """
Get FLOPS of target graph. Get FLOPS of target graph.
Args: Args:
program(Program): The program used to calculate FLOPS. program(Program): The program used to calculate FLOPS.
only_conv(bool): Just return number of mul-adds in convolution and FC layer if `only_conv` is true.
default: True.
detail(bool): Whether to return detail of each convolution layer.
Return:
If `detail` is true, then return a tuple in format `(FLOPs, details)`, otherwise it will just return `FlOPs`
FLOPs(int): The FLOPs of target network.
details(dict): The key is the parameter name of convlution layer and the value is the FLOPs of each convolution layer.
""" """
graph = GraphWrapper(program) graph = GraphWrapper(program)
return _graph_flops(graph, detail=detail) return _graph_flops(graph, only_conv=only_conv, detail=detail)
def _graph_flops(graph, only_conv=False, detail=False): def _graph_flops(graph, only_conv=False, detail=False):
...@@ -44,7 +52,7 @@ def _graph_flops(graph, only_conv=False, detail=False): ...@@ -44,7 +52,7 @@ def _graph_flops(graph, only_conv=False, detail=False):
with_bias = 1 with_bias = 1
else: else:
with_bias = 0 with_bias = 0
op_flops = 2 * h_out * w_out * c_out * (kernel_ops + with_bias) op_flops = h_out * w_out * c_out * (kernel_ops + with_bias)
flops += op_flops flops += op_flops
params2flops[op.inputs("Filter")[0].name()] = op_flops params2flops[op.inputs("Filter")[0].name()] = op_flops
elif op.type() == 'pool2d' and not only_conv: elif op.type() == 'pool2d' and not only_conv:
...@@ -53,14 +61,17 @@ def _graph_flops(graph, only_conv=False, detail=False): ...@@ -53,14 +61,17 @@ def _graph_flops(graph, only_conv=False, detail=False):
k_size = op.attr("ksize") k_size = op.attr("ksize")
flops += h_out * w_out * c_out * (k_size[0]**2) flops += h_out * w_out * c_out * (k_size[0]**2)
elif op.type() == 'mul' and not only_conv: elif op.type() == 'mul':
x_shape = list(op.inputs("X")[0].shape()) x_shape = list(op.inputs("X")[0].shape())
y_shape = op.inputs("Y")[0].shape() y_shape = op.inputs("Y")[0].shape()
if x_shape[0] == -1: if x_shape[0] == -1:
x_shape[0] = 1 x_shape[0] = 1
flops += 2 * x_shape[0] * x_shape[1] * y_shape[1]
elif op.type() in ['relu', 'sigmoid', 'batch_norm'] and not only_conv: op_flops = x_shape[0] * x_shape[1] * y_shape[1]
flops += op_flops
params2flops[op.inputs("Y")[0].name()] = op_flops
elif op.type() in ['relu', 'sigmoid', 'batch_norm', 'relu6'] and not only_conv:
input_shape = list(op.inputs("X")[0].shape()) input_shape = list(op.inputs("X")[0].shape())
if input_shape[0] == -1: if input_shape[0] == -1:
input_shape[0] = 1 input_shape[0] = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册