未验证 提交 9027f861 编写于 作者: A Aman Sanger 提交者: GitHub

Dont overwrite hook handles in flop profiler (#2106)

Co-authored-by: NOlatunji Ruwase <olruwase@microsoft.com>
上级 16699d83
......@@ -74,8 +74,9 @@ class FlopsProfiler(object):
# if computing the flops of a module directly
if type(module) in MODULE_HOOK_MAPPING:
module.__flops_handle__ = module.register_forward_hook(
MODULE_HOOK_MAPPING[type(module)])
if not hasattr(module, "__flops_handle__"):
module.__flops_handle__ = module.register_forward_hook(
MODULE_HOOK_MAPPING[type(module)])
return
# if computing the flops of the functionals in a module
......@@ -83,7 +84,8 @@ class FlopsProfiler(object):
module_flop_count.append([])
module_mac_count.append([])
module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook)
if not hasattr(module, "__pre_hook_handle__"):
module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook)
def post_hook(module, input, output):
if module_flop_count:
......@@ -92,20 +94,24 @@ class FlopsProfiler(object):
module.__macs__ += sum([elem[1] for elem in module_mac_count[-1]])
module_mac_count.pop()
module.__post_hook_handle__ = module.register_forward_hook(post_hook)
if not hasattr(module, "__post_hook_handle__"):
module.__post_hook_handle__ = module.register_forward_hook(post_hook)
def start_time_hook(module, input):
torch.cuda.synchronize()
module.__start_time__ = time.time()
module.__start_time_hook_handle__ = module.register_forward_pre_hook(
start_time_hook)
if not hasattr(module, "__start_time_hook_handle"):
module.__start_time_hook_handle__ = module.register_forward_pre_hook(
start_time_hook)
def end_time_hook(module, input, output):
torch.cuda.synchronize()
module.__duration__ += time.time() - module.__start_time__
module.__end_time_hook_handle__ = module.register_forward_hook(end_time_hook)
if not hasattr(module, "__end_time_hook_handle__"):
module.__end_time_hook_handle__ = module.register_forward_hook(
end_time_hook)
self.model.apply(partial(register_module_hooks, ignore_list=ignore_list))
self.started = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册