diff --git a/python/paddle/hapi/dynamic_flops.py b/python/paddle/hapi/dynamic_flops.py index 2c59ee67d4a8e9ae49bf7eadec2d1b1e9d1e7fce..5c48501c07d7eabf755958431816f56a5472d1e0 100644 --- a/python/paddle/hapi/dynamic_flops.py +++ b/python/paddle/hapi/dynamic_flops.py @@ -17,6 +17,7 @@ import warnings import paddle.nn as nn import numpy as np from .static_flops import static_flops, Table +from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators __all__ = [] @@ -100,6 +101,10 @@ def flops(net, input_size, custom_ops=None, print_detail=False): #Total Flops: 347560 Total Params: 61610 """ if isinstance(net, nn.Layer): + # If net is a dy2stat model, net.forward is StaticFunction instance, + # we set net.forward to original forward function. + _, net.forward = unwrap_decorators(net.forward) + inputs = paddle.randn(input_size) return dynamic_flops( net,