提交 ba02cfee 编写于 作者: W wanghaoshuang

Fix flops API.

上级 1cfa65ec
......@@ -18,33 +18,37 @@ from ..core import GraphWrapper
__all__ = ["flops"]
def flops(program):
def flops(program, detail=False):
"""
Get FLOPS of target graph.
Args:
program(Program): The program used to calculate FLOPS.
"""
graph = GraphWrapper(program)
return _graph_flops(graph)
return _graph_flops(graph, detail=detail)
def _graph_flops(graph, only_conv=False):
def _graph_flops(graph, only_conv=False, detail=False):
assert isinstance(graph, GraphWrapper)
flops = 0
params2flops = {}
for op in graph.ops():
if op.type() in ['conv2d', 'depthwise_conv2d']:
filter_shape = op.inputs("Filter")[0].shape()
input_shape = op.inputs("Input")[0].shape()
output_shape = op.outputs("Output")[0].shape()
c_out, c_in, k_h, k_w = filter_shape
_, c_in, _, _ = input_shape
c_out, _, k_h, k_w = filter_shape
_, _, h_out, w_out = output_shape
groups = op.attr("groups")
kernel_ops = k_h * k_w * (c_in / groups)
kernel_ops = k_h * k_w * (float(c_in) / groups)
if len(op.inputs("Bias")) > 0:
with_bias = 1
else:
with_bias = 0
flops += 2 * h_out * w_out * c_out * (kernel_ops + with_bias)
op_flops = 2 * h_out * w_out * c_out * (kernel_ops + with_bias)
flops += op_flops
params2flops[op.inputs("Filter")[0].name()] = op_flops
elif op.type() == 'pool2d' and not only_conv:
input_shape = op.inputs("X")[0].shape()
output_shape = op.outputs("Out")[0].shape()
......@@ -65,4 +69,7 @@ def _graph_flops(graph, only_conv=False):
input_shape[0] = 1
flops += np.product(input_shape)
return flops
if detail:
return flops, params2flops
else:
return flops
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册