未验证 提交 b5594759 编写于 作者: 0 0x45f 提交者: GitHub

set net.forward to original forward function in flops (#36852) (#37357)

set net.forward to original forward function in flops when net is a dy2stat model.
上级 5fd8312d
......@@ -17,6 +17,7 @@ import warnings
import paddle.nn as nn
import numpy as np
from .static_flops import static_flops, Table
from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators
__all__ = []
......@@ -100,6 +101,10 @@ def flops(net, input_size, custom_ops=None, print_detail=False):
#Total Flops: 347560 Total Params: 61610
"""
if isinstance(net, nn.Layer):
# If net is a dy2stat model, net.forward is StaticFunction instance,
# we set net.forward to original forward function.
_, net.forward = unwrap_decorators(net.forward)
inputs = paddle.randn(input_size)
return dynamic_flops(
net,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册