提交 68fb998d 编写于 作者: C ceci3

add macs

上级 1c428507
......@@ -16,13 +16,12 @@ import numpy as np
import paddle.fluid.dygraph.jit as jit
from ..core import GraphWrapper
# Public API of this module. Note: the name must match the function defined
# below — the original list had a typo ("dygrpah_macs").
__all__ = ["dygraph_flops", "flops", "dygraph_macs", "macs"]
def dygraph_flops(model, inputs, only_conv=True, detail=False):
    """Get FLOPs of a dygraph model.

    Args:
        model: The dygraph model whose FLOPs are calculated.
        inputs: The inputs of the model, used to trace it into a static
            program.
        only_conv(bool): Just return the number of mul-adds in convolution
            and FC layers if `only_conv` is true. Default: True.
        detail(bool): Whether to return the detail of each convolution
            layer. Default: False.
    """
    # Trace the dygraph model into a static program so its graph can be
    # walked by GraphWrapper.
    _, program, _, _, _ = jit._trace(model, inputs)
    graph = GraphWrapper(program)
    return _graph_cals(
        graph, only_conv=only_conv, count='FLOPs', detail=detail)
def flops(program, only_conv=True, detail=False):
    """Get FLOPs of the target graph.

    Args:
        program(Program): The program used to calculate FLOPs.
        only_conv(bool): Just return the number of mul-adds in convolution
            and FC layers if `only_conv` is true. Default: True.
        detail(bool): Whether to return the detail of each convolution
            layer. Default: False.

    Return:
        If `detail` is true, return a tuple in format `(FLOPs, details)`,
        otherwise just return `FLOPs`.
        FLOPs(int): The FLOPs of the target network.
        details(dict): The key is the parameter name of a convolution layer
            and the value is the FLOPs of that layer.
    """
    graph = GraphWrapper(program)
    return _graph_cals(
        graph, only_conv=only_conv, count='FLOPs', detail=detail)
def dygraph_macs(model, inputs, only_conv=False, detail=False):
    """Get MACs (multiply-accumulate operations) of a dygraph model.

    Args:
        model: The dygraph model whose MACs are calculated.
        inputs: The inputs of the model, used to trace it into a static
            program.
        only_conv(bool): Just return the number of mul-adds in convolution
            and FC layers if `only_conv` is true. Default: False.
        detail(bool): Whether to return the detail of each convolution
            layer. Default: False.
    """
    # Trace the dygraph model into a static program so its graph can be
    # walked by GraphWrapper.
    _, program, _, _, _ = jit._trace(model, inputs)
    graph = GraphWrapper(program)
    return _graph_cals(
        graph, only_conv=only_conv, count='MACs', detail=detail)
def _graph_flops(graph, only_conv=True, only_multiply=False, detail=False):
def macs(program, only_conv=True, detail=False):
    """Get MACs (multiply-accumulate operations) of the target graph.

    Args:
        program(Program): The program used to calculate MACs.
        only_conv(bool): Just return the number of mul-adds in convolution
            and FC layers if `only_conv` is true. Default: True.
        detail(bool): Whether to return the detail of each convolution
            layer. Default: False.

    Return:
        If `detail` is true, return a tuple in format `(MACs, details)`,
        otherwise just return `MACs`.
        MACs(int): The MACs of the target network.
        details(dict): The key is the parameter name of a convolution layer
            and the value is the MACs of that layer.
    """
    graph = GraphWrapper(program)
    return _graph_cals(
        graph, only_conv=only_conv, count='MACs', detail=detail)
def _graph_cals(graph, only_conv=True, count, detail=False):
assert isinstance(graph, GraphWrapper)
assert count.lower() in ['flops', 'macs'], "count {} not support now".format(count)
flops = 0
params2flops = {}
for op in graph.ops():
......@@ -77,32 +109,43 @@ def _graph_flops(graph, only_conv=True, only_multiply=False, detail=False):
c_out, c_in, k_h, k_w = filter_shape
_, _, h_out, w_out = output_shape
# c_in is the channel number of filter. It is (input_channel // groups).
kernel_ops = k_h * k_w * float(c_in)
### after dygraph model to static program, conv op donnot have Bias attrs
op_flops = h_out * w_out * c_out * kernel_ops
if count == 'MACs':
out_pixel = k_h * k_w * float(c_in)
else:
### count == 'FLOPs'
### add bias count in elementwise_add op
out_pixel = 2 * k_h * k_w * float(c_in) - 1
op_flops = h_out * w_out * c_out * out_pixel
flops += op_flops
params2flops[op.inputs("Filter")[0].name()] = op_flops
elif op.type() == 'pool2d' and not only_conv and not only_multiply:
output_shape = op.outputs("Out")[0].shape()
_, c_out, h_out, w_out = output_shape
elif op.type() == 'pool2d' and not only_conv:
if count == 'MACs':
op_shape = op.outputs("Out")[0].shape()
else:
op_shape = op.inputs("X")[0].shape()
_, c_out, h_out, w_out = op_shape
k_size = op.attr("ksize")
op_flops = h_out * w_out * c_out * (k_size[0]**2)
flops += op_flops
elif op.type() == 'mul':
elif op.type() == 'mul' and not only_conv:
x_shape = list(op.inputs("X")[0].shape())
y_shape = op.inputs("Y")[0].shape()
if x_shape[0] == -1:
x_shape[0] = 1
flops += x_shape[0] * x_shape[1] * y_shape[1]
op_flops = x_shape[0] * x_shape[1] * y_shape[1]
if count == 'MACs':
op_flops = x_shape[0] * x_shape[1] * y_shape[1]
else:
op_flops = (2 * x_shape[0] * x_shape[1] - 1) * y_shape[1]
flops += op_flops
params2flops[op.inputs("Y")[0].name()] = op_flops
elif op.type() in ['relu', 'sigmoid', 'relu6'
] and not only_conv and not only_multiply:
] and not only_conv:
input_shape = list(op.inputs("X")[0].shape())
if input_shape[0] == -1:
input_shape[0] = 1
......@@ -110,17 +153,21 @@ def _graph_flops(graph, only_conv=True, only_multiply=False, detail=False):
flops += op_flops
elif op.type() in ['batch_norm', 'instance_norm', 'layer_norm'
] and not only_conv and only_multiply:
] and not only_conv:
input_shape = list(op.inputs("X")[0].shape())
if input_shape[0] == -1:
input_shape[0] = 1
### (x - mean) * sqrt(var)
op_flops = np.product(input_shape)
if count == 'FLOPs':
### NOTE: if scale and bias can be none (Need to fix in bn op), it need to add more condition to determine if need to multiply 2
op_flops *= 2
flops += op_flops
elif op.type() in ['elementwise_add'] and not only_multiply:
elif op.type() in ['elementwise_add'] and count == 'FLOPs':
input_shape = list(op.inputs("X")[0].shape())
### if inputs Y is parameter that means add bias after conv
if op.inputs("Y")[0].is_parameter() or not only_conv:
### if inputs Y is parameter that means add bias after conv or norm
if op.inputs("Y")[0].is_parameter():
if input_shape[0] == -1:
input_shape[0] = 1
op_flops = np.product(input_shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册