未验证 提交 3ea12865 编写于 作者: W whs 提交者: GitHub

Fix FLOPs API (#11)

上级 e34a1505
......@@ -18,14 +18,22 @@ from ..core import GraphWrapper
__all__ = ["flops"]
def flops(program, detail=False):
def flops(program, only_conv=True, detail=False):
"""
Get FLOPS of target graph.
Args:
program(Program): The program used to calculate FLOPS.
only_conv(bool): Just return number of mul-adds in convolution and FC layer if `only_conv` is true.
default: True.
detail(bool): Whether to return detail of each convolution layer.
Return:
If `detail` is true, then return a tuple in format `(FLOPs, details)`, otherwise it will just return `FlOPs`
FLOPs(int): The FLOPs of target network.
details(dict): The key is the parameter name of convlution layer and the value is the FLOPs of each convolution layer.
"""
graph = GraphWrapper(program)
return _graph_flops(graph, detail=detail)
return _graph_flops(graph, only_conv=only_conv, detail=detail)
def _graph_flops(graph, only_conv=False, detail=False):
......@@ -44,7 +52,7 @@ def _graph_flops(graph, only_conv=False, detail=False):
with_bias = 1
else:
with_bias = 0
op_flops = 2 * h_out * w_out * c_out * (kernel_ops + with_bias)
op_flops = h_out * w_out * c_out * (kernel_ops + with_bias)
flops += op_flops
params2flops[op.inputs("Filter")[0].name()] = op_flops
elif op.type() == 'pool2d' and not only_conv:
......@@ -53,14 +61,17 @@ def _graph_flops(graph, only_conv=False, detail=False):
k_size = op.attr("ksize")
flops += h_out * w_out * c_out * (k_size[0]**2)
elif op.type() == 'mul' and not only_conv:
elif op.type() == 'mul':
x_shape = list(op.inputs("X")[0].shape())
y_shape = op.inputs("Y")[0].shape()
if x_shape[0] == -1:
x_shape[0] = 1
flops += 2 * x_shape[0] * x_shape[1] * y_shape[1]
elif op.type() in ['relu', 'sigmoid', 'batch_norm'] and not only_conv:
op_flops = x_shape[0] * x_shape[1] * y_shape[1]
flops += op_flops
params2flops[op.inputs("Y")[0].name()] = op_flops
elif op.type() in ['relu', 'sigmoid', 'batch_norm', 'relu6'] and not only_conv:
input_shape = list(op.inputs("X")[0].shape())
if input_shape[0] == -1:
input_shape[0] = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册