未验证 提交 e2a31d80 编写于 作者: S swli 提交者: GitHub

some fix in flops_profiler (#2068)

* bugs in profiler:
1. Tensor.bmm missed in _patch_tensor_methods function
2. missed functions in _reload_functionals and _reload_tensor_methods functions
3. torch.mm and torch.Tensor.mm will have the same __name__ in wrapFunc; my suggestion is to use __str__ instead.

* formatting

---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>
上级 ef6a958e
...@@ -788,7 +788,7 @@ def _elementwise_flops_compute(input, other): ...@@ -788,7 +788,7 @@ def _elementwise_flops_compute(input, other):
def wrapFunc(func, funcFlopCompute): def wrapFunc(func, funcFlopCompute):
oldFunc = func oldFunc = func
name = func.__name__ name = func.__str__
old_functions[name] = oldFunc old_functions[name] = oldFunc
def newFunc(*args, **kwds): def newFunc(*args, **kwds):
...@@ -799,7 +799,7 @@ def wrapFunc(func, funcFlopCompute): ...@@ -799,7 +799,7 @@ def wrapFunc(func, funcFlopCompute):
module_mac_count[-1].append((name, macs)) module_mac_count[-1].append((name, macs))
return oldFunc(*args, **kwds) return oldFunc(*args, **kwds)
newFunc.__name__ = func.__name__ newFunc.__str__ = func.__str__
return newFunc return newFunc
...@@ -865,7 +865,7 @@ def _patch_tensor_methods(): ...@@ -865,7 +865,7 @@ def _patch_tensor_methods():
torch.mm = wrapFunc(torch.mm, _matmul_flops_compute) torch.mm = wrapFunc(torch.mm, _matmul_flops_compute)
torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute) torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute)
torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute) torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute)
torch.Tensor.bmm = wrapFunc(torch.bmm, _matmul_flops_compute) torch.Tensor.bmm = wrapFunc(torch.Tensor.bmm, _matmul_flops_compute)
torch.addmm = wrapFunc(torch.addmm, _addmm_flops_compute) torch.addmm = wrapFunc(torch.addmm, _addmm_flops_compute)
torch.Tensor.addmm = wrapFunc(torch.Tensor.addmm, _tensor_addmm_flops_compute) torch.Tensor.addmm = wrapFunc(torch.Tensor.addmm, _tensor_addmm_flops_compute)
...@@ -878,42 +878,65 @@ def _patch_tensor_methods(): ...@@ -878,42 +878,65 @@ def _patch_tensor_methods():
torch.einsum = wrapFunc(torch.einsum, _einsum_flops_compute) torch.einsum = wrapFunc(torch.einsum, _einsum_flops_compute)
torch.baddbmm = wrapFunc(torch.baddbmm, _tensor_addmm_flops_compute)
def _reload_functionals():
    """Restore the original ``torch.nn.functional`` entry points.

    Undoes the monkey-patching done by the profiler by looking each wrapped
    function up in the module-level ``old_functions`` dict (keyed by the
    original function's ``__str__``, which ``wrapFunc`` copies onto its
    wrapper) and assigning it back.  ``torch.nn.functional`` does not support
    ``importlib.reload()``, so the attributes are reset one by one.
    """
    # Names restored unconditionally, in the same order as the original code.
    unconditional = (
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
        "relu",
        "prelu",
        "elu",
        "leaky_relu",
        "relu6",
    )
    for attr in unconditional:
        setattr(F, attr, old_functions[getattr(F, attr).__str__])

    # F.silu only exists on newer torch versions; mirror the guard used when patching.
    if hasattr(F, "silu"):
        F.silu = old_functions[F.silu.__str__]

    remaining = (
        "gelu",
        "batch_norm",
        "layer_norm",
        "instance_norm",
        "group_norm",
        "avg_pool1d",
        "avg_pool2d",
        "avg_pool3d",
        "max_pool1d",
        "max_pool2d",
        "max_pool3d",
        "adaptive_avg_pool1d",
        "adaptive_avg_pool2d",
        "adaptive_avg_pool3d",
        "adaptive_max_pool1d",
        "adaptive_max_pool2d",
        "adaptive_max_pool3d",
        "upsample",
        "interpolate",
        "softmax",
        "embedding",
    )
    for attr in remaining:
        setattr(F, attr, old_functions[getattr(F, attr).__str__])
def _reload_tensor_methods():
    """Restore the original ``torch`` / ``torch.Tensor`` methods.

    Undoes the monkey-patching done by ``_patch_tensor_methods``: each wrapped
    function is looked up in the module-level ``old_functions`` dict — keyed
    by the original function's ``__str__``, which ``wrapFunc`` copies onto its
    wrapper — and assigned back to its attribute.
    """
    torch.matmul = old_functions[torch.matmul.__str__]
    torch.Tensor.matmul = old_functions[torch.Tensor.matmul.__str__]
    torch.mm = old_functions[torch.mm.__str__]
    torch.Tensor.mm = old_functions[torch.Tensor.mm.__str__]
    # Fix: was old_functions[torch.matmul.__str__], which restored the
    # original matmul into torch.bmm instead of the original bmm.
    torch.bmm = old_functions[torch.bmm.__str__]
    torch.Tensor.bmm = old_functions[torch.Tensor.bmm.__str__]
    torch.addmm = old_functions[torch.addmm.__str__]
    torch.Tensor.addmm = old_functions[torch.Tensor.addmm.__str__]
    torch.mul = old_functions[torch.mul.__str__]
    torch.Tensor.mul = old_functions[torch.Tensor.mul.__str__]
    torch.add = old_functions[torch.add.__str__]
    torch.Tensor.add = old_functions[torch.Tensor.add.__str__]
    torch.einsum = old_functions[torch.einsum.__str__]
    torch.baddbmm = old_functions[torch.baddbmm.__str__]
def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size): def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册