diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 175788c9435ef9f76ef24f70ac5e43d94c2e853b..2ac061116f72579ef6d92d63453d6faf1fa13b2b 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -273,6 +273,7 @@ from . import onnx from .hapi import Model from .hapi import callbacks from .hapi import summary +from .hapi import flops import paddle.text import paddle.vision diff --git a/python/paddle/hapi/__init__.py b/python/paddle/hapi/__init__.py index 67965de5d97621e188acfa1e0384325b9ec5b7aa..0aea557a28c274398f3e5d6422eb2141778bf9ce 100644 --- a/python/paddle/hapi/__init__.py +++ b/python/paddle/hapi/__init__.py @@ -19,7 +19,9 @@ from . import model_summary from . import model from .model import * from .model_summary import summary +from .dynamic_flops import flops logger.setup_logger() __all__ = ['callbacks'] + model.__all__ + ['summary'] +__all__ = model.__all__ + ['flops'] diff --git a/python/paddle/hapi/dynamic_flops.py b/python/paddle/hapi/dynamic_flops.py index bd4679208ee93f4722d6b6757f656e312bb917e4..382227ea8329778d767a47114496d9365e6def8c 100644 --- a/python/paddle/hapi/dynamic_flops.py +++ b/python/paddle/hapi/dynamic_flops.py @@ -16,7 +16,7 @@ import paddle import warnings import paddle.nn as nn import numpy as np -from .static_flops import static_flops, _verify_dependent_package +from .static_flops import static_flops __all__ = ['flops'] @@ -264,7 +264,13 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False): model.train() for handler in handler_collection: handler.remove() - _verify_dependent_package() + + try: + from prettytable import PrettyTable + except ImportError: + raise ImportError( + "paddle.flops() requires package `prettytable`, place install it firstly using `pip install prettytable`. " + ) table = PrettyTable( ["Layer Name", "Input Shape", "Output Shape", "Params", "Flops"]) diff --git a/python/paddle/hapi/static_flops.py b/python/paddle/hapi/static_flops.py index e8870ab8f7e6ba300e1dbd4128731896fb2478ad..9815d4cfff54bab7672484a8d4501e8c59827f3c 100644 --- a/python/paddle/hapi/static_flops.py +++ b/python/paddle/hapi/static_flops.py @@ -166,22 +166,15 @@ def count_element_op(op): return total_ops -def _verify_dependent_package(): - """ - Verify whether `prettytable` is installed. - """ +def _graph_flops(graph, detail=False): + assert isinstance(graph, GraphWrapper) + flops = 0 try: from prettytable import PrettyTable except ImportError: raise ImportError( "paddle.flops() requires package `prettytable`, place install it firstly using `pip install prettytable`. " ) - - -def _graph_flops(graph, detail=False): - assert isinstance(graph, GraphWrapper) - flops = 0 - _verify_dependent_package() table = PrettyTable(["OP Type", 'Param name', "Flops"]) for op in graph.ops(): param_name = '' diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index a410c726af18a736e7f036ce4b72d9924043b37a..af54b046fe699fa29cf6948f990a5cb9d44ddcda 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -33,6 +33,8 @@ from paddle.nn.layer.loss import CrossEntropyLoss from paddle.metric import Accuracy from paddle.vision.datasets import MNIST from paddle.vision.models import LeNet +import paddle.vision.models as models +import paddle.fluid.dygraph.jit as jit from paddle.io import DistributedBatchSampler, Dataset from paddle.hapi.model import prepare_distributed_context from paddle.fluid.dygraph.jit import declarative @@ -564,6 +566,24 @@ class TestModelFunction(unittest.TestCase): nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3) paddle.summary(nlp_net, (1, 1, 2)) + def test_static_flops(self): + paddle.disable_static() + net = models.__dict__['mobilenet_v2'](pretrained=False) + inputs = paddle.randn([1, 3, 224, 224]) + static_program = jit._trace(net, inputs=[inputs])[1] + paddle.flops(static_program, [1, 3, 224, 224], print_detail=True) + + def test_dynamic_flops(self): + net = models.__dict__['mobilenet_v2'](pretrained=False) + + def customize_dropout(m, x, y): + m.total_ops += 0 + + paddle.flops( + net, [1, 3, 224, 224], + custom_ops={paddle.nn.Dropout: customize_dropout}, + print_detail=True) + def test_export_deploy_model(self): self.set_seed() np.random.seed(201)