Unverified commit e2a31d80, authored by swli, committed by GitHub

some fix in flops_profiler (#2068)

* bugs in profiler:
1. Tensor.bmm is missing from the _patch_tensor_methods function
2. missing functions in the _reload_functionals and _reload_tensor_methods functions
3. torch.mm and torch.Tensor.mm get the same __name__ in wrapFunc; my suggestion is to use __str__ instead (see the sketch below).

* formatting

---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Parent ef6a958e
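To make bug 3 concrete, here is a minimal illustration (not part of the commit; the printed representations vary by Python/PyTorch build): torch.mm and torch.Tensor.mm report the same __name__, so a dictionary keyed on __name__ keeps only one saved original and restores the wrong function, while each object's __str__ (the bound method the patch below keys on, and the string it produces) is distinct.

import torch

# Both report the same __name__, so old_functions keyed on __name__ collides:
print(torch.mm.__name__)         # mm
print(torch.Tensor.mm.__name__)  # mm

# Their string representations differ, so they can key separate entries:
print(str(torch.mm))             # e.g. <built-in method mm of type object at 0x...>
print(str(torch.Tensor.mm))      # e.g. <method 'mm' of 'torch._C._TensorBase' objects>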
@@ -788,7 +788,7 @@ def _elementwise_flops_compute(input, other):
 def wrapFunc(func, funcFlopCompute):
     oldFunc = func
-    name = func.__name__
+    name = func.__str__
     old_functions[name] = oldFunc
     def newFunc(*args, **kwds):
@@ -799,7 +799,7 @@ def wrapFunc(func, funcFlopCompute):
             module_mac_count[-1].append((name, macs))
         return oldFunc(*args, **kwds)
-    newFunc.__name__ = func.__name__
+    newFunc.__str__ = func.__str__
     return newFunc
@@ -865,7 +865,7 @@ def _patch_tensor_methods():
     torch.mm = wrapFunc(torch.mm, _matmul_flops_compute)
     torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute)
     torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute)
-    torch.Tensor.bmm = wrapFunc(torch.bmm, _matmul_flops_compute)
+    torch.Tensor.bmm = wrapFunc(torch.Tensor.bmm, _matmul_flops_compute)
     torch.addmm = wrapFunc(torch.addmm, _addmm_flops_compute)
     torch.Tensor.addmm = wrapFunc(torch.Tensor.addmm, _tensor_addmm_flops_compute)
@@ -878,42 +878,65 @@ def _patch_tensor_methods():
     torch.einsum = wrapFunc(torch.einsum, _einsum_flops_compute)
     torch.baddbmm = wrapFunc(torch.baddbmm, _tensor_addmm_flops_compute)
 def _reload_functionals():
     # torch.nn.functional does not support importlib.reload()
-    F.linear = old_functions[F.linear.__name__]
-    F.conv1d = old_functions[F.conv1d.__name__]
-    F.conv2d = old_functions[F.conv2d.__name__]
-    F.conv3d = old_functions[F.conv3d.__name__]
-    F.conv_transpose1d = old_functions[F.conv_transpose1d.__name__]
-    F.conv_transpose2d = old_functions[F.conv_transpose2d.__name__]
-    F.conv_transpose3d = old_functions[F.conv_transpose3d.__name__]
-    F.relu = old_functions[F.relu.__name__]
-    F.prelu = old_functions[F.prelu.__name__]
-    F.elu = old_functions[F.elu.__name__]
-    F.leaky_relu = old_functions[F.leaky_relu.__name__]
-    F.relu6 = old_functions[F.relu6.__name__]
-    F.batch_norm = old_functions[F.batch_norm.__name__]
-    F.avg_pool1d = old_functions[F.avg_pool1d.__name__]
-    F.avg_pool2d = old_functions[F.avg_pool2d.__name__]
-    F.avg_pool3d = old_functions[F.avg_pool3d.__name__]
-    F.max_pool1d = old_functions[F.max_pool1d.__name__]
-    F.max_pool2d = old_functions[F.max_pool2d.__name__]
-    F.max_pool3d = old_functions[F.max_pool3d.__name__]
-    F.adaptive_avg_pool1d = old_functions[F.adaptive_avg_pool1d.__name__]
-    F.adaptive_avg_pool2d = old_functions[F.adaptive_avg_pool2d.__name__]
-    F.adaptive_avg_pool3d = old_functions[F.adaptive_avg_pool3d.__name__]
-    F.adaptive_max_pool1d = old_functions[F.adaptive_max_pool1d.__name__]
-    F.adaptive_max_pool2d = old_functions[F.adaptive_max_pool2d.__name__]
-    F.adaptive_max_pool3d = old_functions[F.adaptive_max_pool3d.__name__]
-    F.upsample = old_functions[F.upsample.__name__]
-    F.interpolate = old_functions[F.interpolate.__name__]
-    F.softmax = old_functions[F.softmax.__name__]
-    F.embedding = old_functions[F.embedding.__name__]
+    F.linear = old_functions[F.linear.__str__]
+    F.conv1d = old_functions[F.conv1d.__str__]
+    F.conv2d = old_functions[F.conv2d.__str__]
+    F.conv3d = old_functions[F.conv3d.__str__]
+    F.conv_transpose1d = old_functions[F.conv_transpose1d.__str__]
+    F.conv_transpose2d = old_functions[F.conv_transpose2d.__str__]
+    F.conv_transpose3d = old_functions[F.conv_transpose3d.__str__]
+    F.relu = old_functions[F.relu.__str__]
+    F.prelu = old_functions[F.prelu.__str__]
+    F.elu = old_functions[F.elu.__str__]
+    F.leaky_relu = old_functions[F.leaky_relu.__str__]
+    F.relu6 = old_functions[F.relu6.__str__]
+    if hasattr(F, "silu"):
+        F.silu = old_functions[F.silu.__str__]
+    F.gelu = old_functions[F.gelu.__str__]
+    F.batch_norm = old_functions[F.batch_norm.__str__]
+    F.layer_norm = old_functions[F.layer_norm.__str__]
+    F.instance_norm = old_functions[F.instance_norm.__str__]
+    F.group_norm = old_functions[F.group_norm.__str__]
+    F.avg_pool1d = old_functions[F.avg_pool1d.__str__]
+    F.avg_pool2d = old_functions[F.avg_pool2d.__str__]
+    F.avg_pool3d = old_functions[F.avg_pool3d.__str__]
+    F.max_pool1d = old_functions[F.max_pool1d.__str__]
+    F.max_pool2d = old_functions[F.max_pool2d.__str__]
+    F.max_pool3d = old_functions[F.max_pool3d.__str__]
+    F.adaptive_avg_pool1d = old_functions[F.adaptive_avg_pool1d.__str__]
+    F.adaptive_avg_pool2d = old_functions[F.adaptive_avg_pool2d.__str__]
+    F.adaptive_avg_pool3d = old_functions[F.adaptive_avg_pool3d.__str__]
+    F.adaptive_max_pool1d = old_functions[F.adaptive_max_pool1d.__str__]
+    F.adaptive_max_pool2d = old_functions[F.adaptive_max_pool2d.__str__]
+    F.adaptive_max_pool3d = old_functions[F.adaptive_max_pool3d.__str__]
+    F.upsample = old_functions[F.upsample.__str__]
+    F.interpolate = old_functions[F.interpolate.__str__]
+    F.softmax = old_functions[F.softmax.__str__]
+    F.embedding = old_functions[F.embedding.__str__]
 def _reload_tensor_methods():
-    torch.matmul = old_functions[torch.matmul.__name__]
+    torch.matmul = old_functions[torch.matmul.__str__]
+    torch.Tensor.matmul = old_functions[torch.Tensor.matmul.__str__]
+    torch.mm = old_functions[torch.mm.__str__]
+    torch.Tensor.mm = old_functions[torch.Tensor.mm.__str__]
+    torch.bmm = old_functions[torch.matmul.__str__]
+    torch.Tensor.bmm = old_functions[torch.Tensor.bmm.__str__]
+    torch.addmm = old_functions[torch.addmm.__str__]
+    torch.Tensor.addmm = old_functions[torch.Tensor.addmm.__str__]
+    torch.mul = old_functions[torch.mul.__str__]
+    torch.Tensor.mul = old_functions[torch.Tensor.mul.__str__]
+    torch.add = old_functions[torch.add.__str__]
+    torch.Tensor.add = old_functions[torch.Tensor.add.__str__]
+    torch.einsum = old_functions[torch.einsum.__str__]
+    torch.baddbmm = old_functions[torch.baddbmm.__str__]
 def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
......
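A simplified, self-contained sketch (not the profiler's actual code) of the save/patch/restore pattern that wrapFunc, _patch_tensor_methods, and the _reload_* functions above follow; saved, wrap, and wrapped are illustrative names, and the FLOP accounting is stubbed out.

import torch

saved = {}  # stand-in for old_functions

def wrap(func, flop_compute):
    # Key on the bound __str__ of the original so torch.mm and torch.Tensor.mm
    # get distinct entries (their __name__ is identical).
    saved[func.__str__] = func

    def wrapped(*args, **kwargs):
        flop_compute(*args, **kwargs)  # record flops/macs here (stubbed)
        return func(*args, **kwargs)

    # Carry the key on the wrapper so it can be looked up at restore time.
    wrapped.__str__ = func.__str__
    return wrapped

# Patch both entry points without the __name__ collision, then restore them.
torch.mm = wrap(torch.mm, lambda *args, **kwargs: None)
torch.Tensor.mm = wrap(torch.Tensor.mm, lambda *args, **kwargs: None)

torch.mm = saved[torch.mm.__str__]
torch.Tensor.mm = saved[torch.Tensor.mm.__str__]

x, y = torch.randn(2, 3), torch.randn(3, 4)
assert torch.equal(torch.mm(x, y), x.mm(y))  # originals restored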