提交 1c428507 编写于 作者: C ceci3

update

上级 c2f19ec1
......@@ -31,18 +31,15 @@ def dygraph_flops(model,
inputs: The inputs of the model, used to calculate FLOPs.
only_conv(bool): Just return number of mul-adds in convolution and FC layer if `only_conv` is true.
Default: True.
only_multiply(bool): Just return number of muliply in the model if `only_multiply` is true.
only_multiply(bool): If `only_multiply` is true, just return number of muliply in the model, the
multiply in such as conv, conv_transpose, norm and mul operators will be count.
Default: False.
detail(bool): Whether to return detail of each convolution layer. Default: False.
"""
_, program, _, _, _ = jit._trace(model, inputs)
graph = GraphWrapper(program)
return _graph_flops(
graph,
only_conv=only_conv,
only_multiply=only_multiply,
dygraph=True,
detail=detail)
graph, only_conv=only_conv, only_multiply=only_multiply, detail=detail)
def flops(program, only_conv=True, only_multiply=False, detail=False):
......@@ -66,11 +63,7 @@ def flops(program, only_conv=True, only_multiply=False, detail=False):
graph, only_conv=only_conv, only_multiply=only_multiply, detail=detail)
def _graph_flops(graph,
only_conv=True,
only_multiply=False,
dygraph=False,
detail=False):
def _graph_flops(graph, only_conv=True, only_multiply=False, detail=False):
assert isinstance(graph, GraphWrapper)
flops = 0
params2flops = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册