diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py index 859f1ac983246c414e7b302b985d01955045784f..a1a6a6fac9f10f43de1d5f73a1ba7de6afb59068 100644 --- a/deepspeed/profiling/flops_profiler/profiler.py +++ b/deepspeed/profiling/flops_profiler/profiler.py @@ -788,7 +788,7 @@ def _elementwise_flops_compute(input, other): def wrapFunc(func, funcFlopCompute): oldFunc = func - name = func.__name__ + name = func.__str__ old_functions[name] = oldFunc def newFunc(*args, **kwds): @@ -799,7 +799,7 @@ def wrapFunc(func, funcFlopCompute): module_mac_count[-1].append((name, macs)) return oldFunc(*args, **kwds) - newFunc.__name__ = func.__name__ + newFunc.__str__ = func.__str__ return newFunc @@ -865,7 +865,7 @@ def _patch_tensor_methods(): torch.mm = wrapFunc(torch.mm, _matmul_flops_compute) torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute) torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute) - torch.Tensor.bmm = wrapFunc(torch.bmm, _matmul_flops_compute) + torch.Tensor.bmm = wrapFunc(torch.Tensor.bmm, _matmul_flops_compute) torch.addmm = wrapFunc(torch.addmm, _addmm_flops_compute) torch.Tensor.addmm = wrapFunc(torch.Tensor.addmm, _tensor_addmm_flops_compute) @@ -878,42 +878,65 @@ def _patch_tensor_methods(): torch.einsum = wrapFunc(torch.einsum, _einsum_flops_compute) + torch.baddbmm = wrapFunc(torch.baddbmm, _tensor_addmm_flops_compute) + def _reload_functionals(): # torch.nn.functional does not support importlib.reload() - F.linear = old_functions[F.linear.__name__] - F.conv1d = old_functions[F.conv1d.__name__] - F.conv2d = old_functions[F.conv2d.__name__] - F.conv3d = old_functions[F.conv3d.__name__] - F.conv_transpose1d = old_functions[F.conv_transpose1d.__name__] - F.conv_transpose2d = old_functions[F.conv_transpose2d.__name__] - F.conv_transpose3d = old_functions[F.conv_transpose3d.__name__] - F.relu = old_functions[F.relu.__name__] - F.prelu = old_functions[F.prelu.__name__] - F.elu = old_functions[F.elu.__name__] - F.leaky_relu = old_functions[F.leaky_relu.__name__] - F.relu6 = old_functions[F.relu6.__name__] - F.batch_norm = old_functions[F.batch_norm.__name__] - F.avg_pool1d = old_functions[F.avg_pool1d.__name__] - F.avg_pool2d = old_functions[F.avg_pool2d.__name__] - F.avg_pool3d = old_functions[F.avg_pool3d.__name__] - F.max_pool1d = old_functions[F.max_pool1d.__name__] - F.max_pool2d = old_functions[F.max_pool2d.__name__] - F.max_pool3d = old_functions[F.max_pool3d.__name__] - F.adaptive_avg_pool1d = old_functions[F.adaptive_avg_pool1d.__name__] - F.adaptive_avg_pool2d = old_functions[F.adaptive_avg_pool2d.__name__] - F.adaptive_avg_pool3d = old_functions[F.adaptive_avg_pool3d.__name__] - F.adaptive_max_pool1d = old_functions[F.adaptive_max_pool1d.__name__] - F.adaptive_max_pool2d = old_functions[F.adaptive_max_pool2d.__name__] - F.adaptive_max_pool3d = old_functions[F.adaptive_max_pool3d.__name__] - F.upsample = old_functions[F.upsample.__name__] - F.interpolate = old_functions[F.interpolate.__name__] - F.softmax = old_functions[F.softmax.__name__] - F.embedding = old_functions[F.embedding.__name__] + F.linear = old_functions[F.linear.__str__] + F.conv1d = old_functions[F.conv1d.__str__] + F.conv2d = old_functions[F.conv2d.__str__] + F.conv3d = old_functions[F.conv3d.__str__] + F.conv_transpose1d = old_functions[F.conv_transpose1d.__str__] + F.conv_transpose2d = old_functions[F.conv_transpose2d.__str__] + F.conv_transpose3d = old_functions[F.conv_transpose3d.__str__] + F.relu = old_functions[F.relu.__str__] + F.prelu = old_functions[F.prelu.__str__] + F.elu = old_functions[F.elu.__str__] + F.leaky_relu = old_functions[F.leaky_relu.__str__] + F.relu6 = old_functions[F.relu6.__str__] + if hasattr(F, "silu"): + F.silu = old_functions[F.silu.__str__] + F.gelu = old_functions[F.gelu.__str__] + F.batch_norm = old_functions[F.batch_norm.__str__] + F.layer_norm = old_functions[F.layer_norm.__str__] + F.instance_norm = old_functions[F.instance_norm.__str__] + F.group_norm = old_functions[F.group_norm.__str__] + F.avg_pool1d = old_functions[F.avg_pool1d.__str__] + F.avg_pool2d = old_functions[F.avg_pool2d.__str__] + F.avg_pool3d = old_functions[F.avg_pool3d.__str__] + F.max_pool1d = old_functions[F.max_pool1d.__str__] + F.max_pool2d = old_functions[F.max_pool2d.__str__] + F.max_pool3d = old_functions[F.max_pool3d.__str__] + F.adaptive_avg_pool1d = old_functions[F.adaptive_avg_pool1d.__str__] + F.adaptive_avg_pool2d = old_functions[F.adaptive_avg_pool2d.__str__] + F.adaptive_avg_pool3d = old_functions[F.adaptive_avg_pool3d.__str__] + F.adaptive_max_pool1d = old_functions[F.adaptive_max_pool1d.__str__] + F.adaptive_max_pool2d = old_functions[F.adaptive_max_pool2d.__str__] + F.adaptive_max_pool3d = old_functions[F.adaptive_max_pool3d.__str__] + F.upsample = old_functions[F.upsample.__str__] + F.interpolate = old_functions[F.interpolate.__str__] + F.softmax = old_functions[F.softmax.__str__] + F.embedding = old_functions[F.embedding.__str__] def _reload_tensor_methods(): - torch.matmul = old_functions[torch.matmul.__name__] + torch.matmul = old_functions[torch.matmul.__str__] + torch.Tensor.matmul = old_functions[torch.Tensor.matmul.__str__] + torch.mm = old_functions[torch.mm.__str__] + torch.Tensor.mm = old_functions[torch.Tensor.mm.__str__] + torch.bmm = old_functions[torch.matmul.__str__] + torch.Tensor.bmm = old_functions[torch.Tensor.bmm.__str__] + torch.addmm = old_functions[torch.addmm.__str__] + torch.Tensor.addmm = old_functions[torch.Tensor.addmm.__str__] + torch.mul = old_functions[torch.mul.__str__] + torch.Tensor.mul = old_functions[torch.Tensor.mul.__str__] + torch.add = old_functions[torch.add.__str__] + torch.Tensor.add = old_functions[torch.Tensor.add.__str__] + + torch.einsum = old_functions[torch.einsum.__str__] + + torch.baddbmm = old_functions[torch.baddbmm.__str__] def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):