未验证 提交 09f99c9e 编写于 作者: W whs 提交者: GitHub

Fix flops api. (#8)

上级 31513264
...@@ -35,13 +35,11 @@ def _graph_flops(graph, only_conv=False, detail=False): ...@@ -35,13 +35,11 @@ def _graph_flops(graph, only_conv=False, detail=False):
for op in graph.ops(): for op in graph.ops():
if op.type() in ['conv2d', 'depthwise_conv2d']: if op.type() in ['conv2d', 'depthwise_conv2d']:
filter_shape = op.inputs("Filter")[0].shape() filter_shape = op.inputs("Filter")[0].shape()
input_shape = op.inputs("Input")[0].shape()
output_shape = op.outputs("Output")[0].shape() output_shape = op.outputs("Output")[0].shape()
_, c_in, _, _ = input_shape c_out, c_in, k_h, k_w = filter_shape
c_out, _, k_h, k_w = filter_shape
_, _, h_out, w_out = output_shape _, _, h_out, w_out = output_shape
groups = op.attr("groups") # c_in is the channel number of filter. It is (input_channel // groups).
kernel_ops = k_h * k_w * (float(c_in) / groups) kernel_ops = k_h * k_w * float(c_in)
if len(op.inputs("Bias")) > 0: if len(op.inputs("Bias")) > 0:
with_bias = 1 with_bias = 1
else: else:
...@@ -50,7 +48,6 @@ def _graph_flops(graph, only_conv=False, detail=False): ...@@ -50,7 +48,6 @@ def _graph_flops(graph, only_conv=False, detail=False):
flops += op_flops flops += op_flops
params2flops[op.inputs("Filter")[0].name()] = op_flops params2flops[op.inputs("Filter")[0].name()] = op_flops
elif op.type() == 'pool2d' and not only_conv: elif op.type() == 'pool2d' and not only_conv:
input_shape = op.inputs("X")[0].shape()
output_shape = op.outputs("Out")[0].shape() output_shape = op.outputs("Out")[0].shape()
_, c_out, h_out, w_out = output_shape _, c_out, h_out, w_out = output_shape
k_size = op.attr("ksize") k_size = op.attr("ksize")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册