未验证 提交 94bcc2ab 编写于 作者: 0 0x45f 提交者: GitHub

set net.forward to original forward function in flops (#36852)

上级 0666b858
...@@ -17,6 +17,7 @@ import warnings ...@@ -17,6 +17,7 @@ import warnings
import paddle.nn as nn import paddle.nn as nn
import numpy as np import numpy as np
from .static_flops import static_flops, Table from .static_flops import static_flops, Table
from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators
__all__ = [] __all__ = []
...@@ -100,6 +101,10 @@ def flops(net, input_size, custom_ops=None, print_detail=False): ...@@ -100,6 +101,10 @@ def flops(net, input_size, custom_ops=None, print_detail=False):
#Total Flops: 347560 Total Params: 61610 #Total Flops: 347560 Total Params: 61610
""" """
if isinstance(net, nn.Layer): if isinstance(net, nn.Layer):
# If net is a dy2stat model, net.forward is StaticFunction instance,
# we set net.forward to original forward function.
_, net.forward = unwrap_decorators(net.forward)
inputs = paddle.randn(input_size) inputs = paddle.randn(input_size)
return dynamic_flops( return dynamic_flops(
net, net,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册