未验证 提交 99fd9815 编写于 作者: Y yukavio 提交者: GitHub

fix flops api (#31081)

* remove PrettyTable dependence from paddle.flops

* fix bug in python2.7

* fix flops

* fix flops

* fix bug

* fix bug
上级 364cfa26
......@@ -121,7 +121,7 @@ def count_convNd(m, x, y):
bias_ops = 1 if m.bias is not None else 0
total_ops = int(y.numel()) * (
x.shape[1] / m._groups * kernel_ops + bias_ops)
m.total_ops += total_ops
m.total_ops += abs(int(total_ops))
def count_leaky_relu(m, x, y):
......@@ -135,15 +135,14 @@ def count_bn(m, x, y):
nelements = x.numel()
if not m.training:
total_ops = 2 * nelements
m.total_ops += int(total_ops)
m.total_ops += abs(int(total_ops))
def count_linear(m, x, y):
total_mul = m.weight.shape[0]
num_elements = y.numel()
total_ops = total_mul * num_elements
m.total_ops += int(total_ops)
m.total_ops += abs(int(total_ops))
def count_avgpool(m, x, y):
......@@ -161,8 +160,7 @@ def count_adap_avgpool(m, x, y):
kernel_ops = total_add + total_div
num_elements = y.numel()
total_ops = kernel_ops * num_elements
m.total_ops += int(total_ops)
m.total_ops += abs(int(total_ops))
def count_zero_ops(m, x, y):
......@@ -173,7 +171,7 @@ def count_parameters(m, x, y):
total_params = 0
for p in m.parameters():
total_params += p.numel()
m.total_params[0] = int(total_params)
m.total_params[0] = abs(int(total_params))
def count_io_info(m, x, y):
......
......@@ -127,6 +127,7 @@ def count_convNd(op):
bias_ops = 1 if len(op.inputs("Bias")) > 0 else 0
output_numel = np.product(op.outputs("Output")[0].shape()[1:])
total_ops = output_numel * (filter_ops + bias_ops)
total_ops = abs(total_ops)
return total_ops
......@@ -138,6 +139,7 @@ def count_leaky_relu(op):
def count_bn(op):
output_numel = np.product(op.outputs("Y")[0].shape()[1:])
total_ops = 2 * output_numel
total_ops = abs(total_ops)
return total_ops
......@@ -145,6 +147,7 @@ def count_linear(op):
total_mul = op.inputs("Y")[0].shape()[0]
numel = np.product(op.outputs("Out")[0].shape()[1:])
total_ops = total_mul * numel
total_ops = abs(total_ops)
return total_ops
......@@ -157,12 +160,14 @@ def count_pool2d(op):
kernel_ops = total_add + total_div
num_elements = np.product(output_shape[1:])
total_ops = kernel_ops * num_elements
total_ops = abs(total_ops)
return total_ops
def count_element_op(op):
input_shape = op.inputs("X")[0].shape()
total_ops = np.product(input_shape[1:])
total_ops = abs(total_ops)
return total_ops
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册