提交 887f2794 编写于 作者: W wanghaoshuang

Merge branch 'fix_flops' into 'develop'

Fix flops API.

See merge request !64
...@@ -18,33 +18,37 @@ from ..core import GraphWrapper ...@@ -18,33 +18,37 @@ from ..core import GraphWrapper
__all__ = ["flops"] __all__ = ["flops"]
def flops(program): def flops(program, detail=False):
""" """
Get FLOPS of target graph. Get FLOPS of target graph.
Args: Args:
program(Program): The program used to calculate FLOPS. program(Program): The program used to calculate FLOPS.
""" """
graph = GraphWrapper(program) graph = GraphWrapper(program)
return _graph_flops(graph) return _graph_flops(graph, detail=detail)
def _graph_flops(graph, only_conv=False): def _graph_flops(graph, only_conv=False, detail=False):
assert isinstance(graph, GraphWrapper) assert isinstance(graph, GraphWrapper)
flops = 0 flops = 0
params2flops = {}
for op in graph.ops(): for op in graph.ops():
if op.type() in ['conv2d', 'depthwise_conv2d']: if op.type() in ['conv2d', 'depthwise_conv2d']:
filter_shape = op.inputs("Filter")[0].shape() filter_shape = op.inputs("Filter")[0].shape()
input_shape = op.inputs("Input")[0].shape() input_shape = op.inputs("Input")[0].shape()
output_shape = op.outputs("Output")[0].shape() output_shape = op.outputs("Output")[0].shape()
c_out, c_in, k_h, k_w = filter_shape _, c_in, _, _ = input_shape
c_out, _, k_h, k_w = filter_shape
_, _, h_out, w_out = output_shape _, _, h_out, w_out = output_shape
groups = op.attr("groups") groups = op.attr("groups")
kernel_ops = k_h * k_w * (c_in / groups) kernel_ops = k_h * k_w * (float(c_in) / groups)
if len(op.inputs("Bias")) > 0: if len(op.inputs("Bias")) > 0:
with_bias = 1 with_bias = 1
else: else:
with_bias = 0 with_bias = 0
flops += 2 * h_out * w_out * c_out * (kernel_ops + with_bias) op_flops = 2 * h_out * w_out * c_out * (kernel_ops + with_bias)
flops += op_flops
params2flops[op.inputs("Filter")[0].name()] = op_flops
elif op.type() == 'pool2d' and not only_conv: elif op.type() == 'pool2d' and not only_conv:
input_shape = op.inputs("X")[0].shape() input_shape = op.inputs("X")[0].shape()
output_shape = op.outputs("Out")[0].shape() output_shape = op.outputs("Out")[0].shape()
...@@ -65,4 +69,7 @@ def _graph_flops(graph, only_conv=False): ...@@ -65,4 +69,7 @@ def _graph_flops(graph, only_conv=False):
input_shape[0] = 1 input_shape[0] = 1
flops += np.product(input_shape) flops += np.product(input_shape)
return flops if detail:
return flops, params2flops
else:
return flops
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册