未验证 提交 99fd9815 编写于 作者: Y yukavio 提交者: GitHub

fix flops api (#31081)

* remove PrettyTable dependence from paddle.flops

* fix bug in python2.7

* fix flops

* fix flops

* fix bug

* fix bug
上级 364cfa26
...@@ -121,7 +121,7 @@ def count_convNd(m, x, y): ...@@ -121,7 +121,7 @@ def count_convNd(m, x, y):
bias_ops = 1 if m.bias is not None else 0 bias_ops = 1 if m.bias is not None else 0
total_ops = int(y.numel()) * ( total_ops = int(y.numel()) * (
x.shape[1] / m._groups * kernel_ops + bias_ops) x.shape[1] / m._groups * kernel_ops + bias_ops)
m.total_ops += total_ops m.total_ops += abs(int(total_ops))
def count_leaky_relu(m, x, y): def count_leaky_relu(m, x, y):
...@@ -135,15 +135,14 @@ def count_bn(m, x, y): ...@@ -135,15 +135,14 @@ def count_bn(m, x, y):
nelements = x.numel() nelements = x.numel()
if not m.training: if not m.training:
total_ops = 2 * nelements total_ops = 2 * nelements
m.total_ops += abs(int(total_ops))
m.total_ops += int(total_ops)
def count_linear(m, x, y): def count_linear(m, x, y):
total_mul = m.weight.shape[0] total_mul = m.weight.shape[0]
num_elements = y.numel() num_elements = y.numel()
total_ops = total_mul * num_elements total_ops = total_mul * num_elements
m.total_ops += int(total_ops) m.total_ops += abs(int(total_ops))
def count_avgpool(m, x, y): def count_avgpool(m, x, y):
...@@ -161,8 +160,7 @@ def count_adap_avgpool(m, x, y): ...@@ -161,8 +160,7 @@ def count_adap_avgpool(m, x, y):
kernel_ops = total_add + total_div kernel_ops = total_add + total_div
num_elements = y.numel() num_elements = y.numel()
total_ops = kernel_ops * num_elements total_ops = kernel_ops * num_elements
m.total_ops += abs(int(total_ops))
m.total_ops += int(total_ops)
def count_zero_ops(m, x, y): def count_zero_ops(m, x, y):
...@@ -173,7 +171,7 @@ def count_parameters(m, x, y): ...@@ -173,7 +171,7 @@ def count_parameters(m, x, y):
total_params = 0 total_params = 0
for p in m.parameters(): for p in m.parameters():
total_params += p.numel() total_params += p.numel()
m.total_params[0] = int(total_params) m.total_params[0] = abs(int(total_params))
def count_io_info(m, x, y): def count_io_info(m, x, y):
......
...@@ -127,6 +127,7 @@ def count_convNd(op): ...@@ -127,6 +127,7 @@ def count_convNd(op):
bias_ops = 1 if len(op.inputs("Bias")) > 0 else 0 bias_ops = 1 if len(op.inputs("Bias")) > 0 else 0
output_numel = np.product(op.outputs("Output")[0].shape()[1:]) output_numel = np.product(op.outputs("Output")[0].shape()[1:])
total_ops = output_numel * (filter_ops + bias_ops) total_ops = output_numel * (filter_ops + bias_ops)
total_ops = abs(total_ops)
return total_ops return total_ops
...@@ -138,6 +139,7 @@ def count_leaky_relu(op): ...@@ -138,6 +139,7 @@ def count_leaky_relu(op):
def count_bn(op): def count_bn(op):
output_numel = np.product(op.outputs("Y")[0].shape()[1:]) output_numel = np.product(op.outputs("Y")[0].shape()[1:])
total_ops = 2 * output_numel total_ops = 2 * output_numel
total_ops = abs(total_ops)
return total_ops return total_ops
...@@ -145,6 +147,7 @@ def count_linear(op): ...@@ -145,6 +147,7 @@ def count_linear(op):
total_mul = op.inputs("Y")[0].shape()[0] total_mul = op.inputs("Y")[0].shape()[0]
numel = np.product(op.outputs("Out")[0].shape()[1:]) numel = np.product(op.outputs("Out")[0].shape()[1:])
total_ops = total_mul * numel total_ops = total_mul * numel
total_ops = abs(total_ops)
return total_ops return total_ops
...@@ -157,12 +160,14 @@ def count_pool2d(op): ...@@ -157,12 +160,14 @@ def count_pool2d(op):
kernel_ops = total_add + total_div kernel_ops = total_add + total_div
num_elements = np.product(output_shape[1:]) num_elements = np.product(output_shape[1:])
total_ops = kernel_ops * num_elements total_ops = kernel_ops * num_elements
total_ops = abs(total_ops)
return total_ops return total_ops
def count_element_op(op): def count_element_op(op):
input_shape = op.inputs("X")[0].shape() input_shape = op.inputs("X")[0].shape()
total_ops = np.product(input_shape[1:]) total_ops = np.product(input_shape[1:])
total_ops = abs(total_ops)
return total_ops return total_ops
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册